Merge remote-tracking branch 'aosp/upstream-master' into rebase_tf

Bug: 116318213
Test: mm
Test: NeuralNetworksTest_static
Change-Id: I249cedf6b76acb8d5ab82c67cacf82885355853d
diff --git a/CODEOWNERS b/CODEOWNERS
index 78f80c8..94cc865 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -2,6 +2,7 @@
 
 /tenosrflow/core/debug @caisq
 /tensorflow/core/platform/windows/ @mrry
+/tensorflow/core/platform/s3 @yongtang
 /tensorflow/go @asimshankar
 /tensorflow/java/ @asimshankar
 /tensorflow/python/debug @caisq
@@ -30,14 +31,16 @@
 /tensorflow/contrib/gan/ @joel-shor
 /tensorflow/contrib/graph_editor/ @purpledog
 # NEED OWNER: /tensorflow/contrib/grid_rnn/
+/tensorflow/contrib/hadoop @yongtang
 /tensorflow/contrib/hvx/ @satok16
 /tensorflow/contrib/integrate/ @shoyer
+/tensorflow/contrib/kafka @yongtang
 /tensorflow/contrib/kernel_methods/ @petrosmol
+/tensorflow/contrib/kinesis @yongtang
 /tensorflow/contrib/ios_examples/ @petewarden
 /tensorflow/contrib/labeled_tensor/ @shoyer
 /tensorflow/contrib/layers/ @fchollet @martinwicke
 /tensorflow/contrib/learn/ @martinwicke @ispirmustafa @alextp
-/tensorflow/contrib/linalg/ @langmore
 /tensorflow/contrib/linear_optimizer/ @petrosmol @andreasst @katsiapis
 /tensorflow/contrib/lookup/ @ysuematsu @andreasst
 /tensorflow/contrib/losses/ @alextp @ispirmustafa
diff --git a/RELEASE.md b/RELEASE.md
index bdc2379..20e1d92 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -1,9 +1,86 @@
+# Release 1.11.0
+
+## Major Features and Improvements
+
+* Nvidia GPU:
+  * Prebuilt binaries are now (as of TensorFlow 1.11) built against cuDNN 7.2 and TensorRT 4. See updated install guides: [Installing TensorFlow on Ubuntu](https://www.tensorflow.org/install/install_linux#tensorflow_gpu_support)
+* Google Cloud TPU:
+  * Experimental tf.data integration for Keras on Google Cloud TPUs.
+  * Experimental / preview support for eager execution on Google Cloud TPUs.
+* DistributionStrategy:
+  * Add multi-GPU DistributionStrategy support in tf.keras. Users can now use `fit`, `evaluate` and `predict` to distribute their model on multiple GPUs.
+  * Add multi-worker DistributionStrategy and standalone client support in Estimator. See [README](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/distribute) for more details.
+* Add C, C++, and Python functions for querying kernels
+
+## Breaking Changes
+
+* Keras:
+  * The default values for tf.keras `RandomUniform`, `RandomNormal`, and `TruncatedNormal` initializers have been changed to match those in external Keras.
+  * Breaking change: `model.get_config()` on a Sequential model now returns a config dictionary (consistent with other Model instances) instead of a list of configs for the underlying layers.
+
+## Bug Fixes and Other Changes
+
+* C++:
+  * Changed the signature of SessionFactory::NewSession so that it can return a meaningful error message on failure.
+* tf.data:
+  * Remove `num_parallel_parser_calls` argument from `tf.contrib.data.make_csv_dataset()`.
+  * `tf.data.Dataset.list_files()` raises an exception at initialization time if the argument matches no files.
+  * Renamed BigTable class to BigtableTable for clarity.
+  * Document use of the Cloud Bigtable API.
+  * Add `tf.contrib.data.reduce_dataset`, which can be used to reduce a dataset to a single element.
+  * Generalization of `tf.contrib.data.sliding_window_batch`.
+* INC:
+  * Runtime improvements to triangular solve.
+* `tf.contrib`:
+  * Add an `implementation` argument to `tf.keras.layers.LocallyConnected2D` and `tf.keras.layers.LocallyConnected1D`. The new mode (`implementation=2`) performs the forward pass as a single dense matrix multiplication, allowing dramatic speedups in certain scenarios (but worse performance in others - see docstring). The option also allows using `padding=same`.
+  * Add documentation clarifying the differences between tf.fill and tf.constant.
+  * Add experimental IndexedDatasets.
+  * Add selective registration target using the lite proto runtime.
+  * Add simple Tensor and DataType classes to TensorFlow Lite Java
+  * Add support for bitcasting to/from uint32 and uint64.
+  * Added a subclass of Estimator that can be created from a SavedModel (SavedModelEstimator).
+  * Adds leaf index modes as an argument.
+  * Allow a different output shape from the input in tf.contrib.image.transform.
+  * Change the state_size order of the StackedRNNCell to be the natural order. To keep the existing behavior, users can pass reverse_state_order=True when constructing the StackedRNNCells.
+  * Deprecate self.test_session() in favor of self.session() or self.cached_session().
+  * Directly import tensor.proto.h (the transitive import will be removed from tensor.h soon)
+  * Estimator.train() now supports tf.contrib.summary.\* summaries out of the box; each call to .train() will now create a separate tfevents file rather than re-using a shared one.
+  * Fix FTRL L2-shrinkage behavior: the gradient from the L2 shrinkage term should not end up in the accumulator.
+  * Fix toco compilation/execution on Windows
+  * Added GoogleZoneProvider class to detect which Google Compute Engine zone TensorFlow is running in.
+  * It is now safe to call any of the C API's TF_Delete\* functions on nullptr
+  * Log some errors on Android to logcat
+  * Match FakeQuant numerics in TFLite to improve accuracy of TFLite quantized inference models.
+  * Optional bucket location check for the GCS Filesystem.
+  * Performance enhancements for StringSplitOp & StringSplitV2Op.
+  * Performance improvements for regex replace operations.
+  * TFRecordWriter now raises an error if .write() fails.
+  * TPU: More helpful error messages in TPUClusterResolvers.
+  * The legacy_init_op argument to SavedModelBuilder methods for adding MetaGraphs has been deprecated. Please use the equivalent main_op argument instead. As part of this, we now explicitly check for a single main_op or legacy_init_op at the time of SavedModel building, whereas the check on main_op was previously only done at load time.
+  * The protocol used for Estimator training is now configurable in RunConfig.
+  * Triangular solve performance improvements.
+  * Unify RNN cell interface between TF and Keras. Add new get_initial_state() to Keras and TF RNN cell, which will be used to replace the existing zero_state() method.
+  * Update initialization of variables in Keras.
+  * Updates to "constrained_optimization" in tensorflow/contrib.
+  * Boosted trees: add pruning mode.
+  * tf.train.Checkpoint does not delete old checkpoints by default.
+  * tfdbg: Limit the total disk space occupied by dumped tensor data to 100 GBytes. Add environment variable `TFDBG_DISK_BYTES_LIMIT` to allow adjustment of this upper limit.
+
+## Thanks to our Contributors
+
+This release contains contributions from many people at Google, as well as:
+
+Aapeli, adoda, Ag Ramesh, Amogh Mannekote, Andrew Gibiansky, Andy Craze, Anirudh Koul, Aurelien Geron, Avijit, Avijit-Nervana, Ben, Benjamin H. Myara, bhack, Brett Koonce, Cao Zongyan, cbockman, cheerss, Chikanaga Tomoyuki, Clayne Robison, cosine0, Cui Wei, Dan J, David, David Norman, Dmitry Klimenkov, Eliel Hojman, Florian Courtial, fo40225, formath, Geoffrey Irving, gracehoney, Grzegorz Pawelczak, Guoliang Hua, Guozhong Zhuang, Herman Zvonimir DošIlović, HuiyangFei, Jacker, Jan HüNnemeyer, Jason Taylor, Jason Zaman, Jesse, Jiang,Zhoulong, Jiawei Zhang, Jie, Joe Yearsley, Johannes Schmitz, Jon Perl, Jon Triebenbach, Jonathan, Jonathan Hseu, Jongmin Park, Justin Shenk, karl@kubx.ca, Kate Hodesdon, Kb Sriram, Keishi Hattori, Kenneth Blomqvist, Koan-Sin Tan, Li Liangbin, Li, Yiqiang, Loo Rong Jie, Madiyar, Mahmoud Abuzaina, Mark Ryan, Matt Dodge, mbhuiyan, melvinljy96, Miguel Mota, Nafis Sadat, Nathan Luehr, naurril, Nehal J Wani, Niall Moran, Niranjan Hasabnis, Nishidha Panpaliya, npow, olicht, Pei Zhang, Peng Wang (Simpeng), Peng Yu, Philipp Jund, Pradeep Banavara, Pratik Kalshetti, qwertWZ, Rakesh Chada, Randy West, Ray Kim, Rholais Lii, Robin Richtsfeld, Rodrigo Silveira, Ruizhi, Santosh Kumar, Seb Bro, Sergei Lebedev, sfujiwara, Shaba Abhiram, Shashi, SneakyFish5, Soila Kavulya, Stefan Dyulgerov, Steven Winston, Sunitha Kambhampati, Surry Shome, Taehoon Lee, Thor Johnsen, Tristan Rice, TShapinsky, tucan, tucan9389, Vicente Reyes, Vilmar-Hillow, Vitaly Lavrukhin, wangershi, weidan.kong, weidankong, Wen-Heng (Jack) Chung, William D. Irons, Wim Glenn, XFeiF, Yan Facai (颜发才), Yanbo Liang, Yong Tang, Yoshihiro Yamazaki, Yuan (Terry) Tang, Yuan, Man, zhaoyongke, ÁRon
+Ricardo Perez-Lopez, 张天启, 张晓飞
+
+
 # Release 1.10.1
 ## Bug Fixes and Other Changes
 
 * `tf.keras`:
   * Fixing keras on Cloud TPUs. No new binaries will be built for Windows.
 
+
 # Release 1.10.0
 
 ## Major Features And Improvements
@@ -17,7 +94,7 @@
 
 ## Breaking Changes
 
-* Prebuilt binaries are now (as of TensorFlow 1.10) built against NCCL 2.2 and no longer include NCCL in the binary install. TensorFlow usage with multiple GPUs and NCCL requires upgrade to [NCCL 2.2](https://developer.nvidia.com/nccl). See updated install guides: [Installing TensorFlow on Ubuntu](https://www.tensorflow.org/install/install_linux#tensorflow_gpu_support) and [Install TensorFlow from Sources](https://www.tensorflow.org/install/install_sources#optional_install_tensorflow_for_gpu_prerequisites).
+* Prebuilt binaries are now (as of TensorFlow 1.10) built against NCCL 2.2 and no longer include NCCL in the binary install. TensorFlow usage with multiple GPUs and NCCL requires upgrade to [NCCL 2.2](https://developer.nvidia.com/nccl). See updated install guides: [TensorFlow GPU support](https://www.tensorflow.org/install/gpu) and [Build TensorFlow from source](https://www.tensorflow.org/install/source).
 * Starting from TensorFlow 1.11, Windows builds will use Bazel. Therefore, we will drop official support for cmake.
 
 ## Bug Fixes and Other Changes
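Note (illustrative, not part of this patch): the C API item in the 1.11.0 notes above states that the `TF_Delete*` functions are now safe to call on nullptr. A minimal sketch of what that guarantee permits in cleanup paths; the specific deleters shown are just common examples from `tensorflow/c/c_api.h`:

```c++
#include "tensorflow/c/c_api.h"

int main() {
  // Per the 1.11 release note, TF_Delete* tolerates nullptr, so unconditional
  // cleanup in error paths no longer needs explicit null checks.
  TF_Status* status = nullptr;
  TF_Buffer* buffer = nullptr;
  TF_Tensor* tensor = nullptr;

  TF_DeleteStatus(status);  // no-op on nullptr
  TF_DeleteBuffer(buffer);  // no-op on nullptr
  TF_DeleteTensor(tensor);  // no-op on nullptr
  return 0;
}
```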
diff --git a/configure.py b/configure.py
index 52a5137..e9d162f 100644
--- a/configure.py
+++ b/configure.py
@@ -1572,6 +1572,9 @@
   if is_windows():
     set_windows_build_flags(environ_cp)
 
+  # Add a config option to build TensorFlow 2.0 API.
+  write_to_bazelrc('build:v2 --define=tf_api_version=2')
+
   if get_var(
       environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace',
       False,
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 386e009..c8e24e3 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -608,6 +608,47 @@
     ],
 )
 
+genrule(
+    name = "install_headers",
+    srcs = [
+        "//tensorflow/c:headers",
+        "//tensorflow/c/eager:headers",
+        "//tensorflow/cc:headers",
+        "//tensorflow/core:headers",
+    ],
+    outs = ["include"],
+    cmd = """
+    mkdir $@
+    for f in $(SRCS); do
+      d="$${f%/*}"
+      d="$${d#bazel-out*genfiles/}"
+      d="$${d#*external/eigen_archive/}"
+
+      if [[ $${d} == *local_config_* ]]; then
+        continue
+      fi
+
+      mkdir -p "$@/$${d}"
+      cp "$${f}" "$@/$${d}/"
+    done
+    """,
+    tags = ["manual"],
+    visibility = ["//visibility:public"],
+)
+
+genrule(
+    name = "root_init_gen",
+    srcs = select({
+        "api_version_2": [":tf_python_api_gen_v2"],
+        "//conditions:default": [":tf_python_api_gen_v1"],
+    }),
+    outs = ["__init__.py"],
+    cmd = select({
+        "api_version_2": "cp $(@D)/_api/v2/__init__.py $(OUTS)",
+        "//conditions:default": "cp $(@D)/_api/v1/__init__.py $(OUTS)",
+    }),
+)
+
 gen_api_init_files(
     name = "tf_python_api_gen_v1",
     srcs = ["api_template.__init__.py"],
@@ -629,19 +670,6 @@
     root_init_template = "api_template.__init__.py",
 )
 
-genrule(
-    name = "root_init_gen",
-    srcs = select({
-        "api_version_2": [":tf_python_api_gen_v2"],
-        "//conditions:default": [":tf_python_api_gen_v1"],
-    }),
-    outs = ["__init__.py"],
-    cmd = select({
-        "api_version_2": "cp $(@D)/_api/v2/__init__.py $(OUTS)",
-        "//conditions:default": "cp $(@D)/_api/v1/__init__.py $(OUTS)",
-    }),
-)
-
 py_library(
     name = "tensorflow_py",
     srcs = ["//tensorflow/python/estimator/api:estimator_python_api_gen"],
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index c195c9e..3bcc62c 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -8705,3 +8705,53 @@
 
   return createTFEDequeue(ctx, TF_VARIANT, queue, status);
 }
+
+static void CheckOk(TF_Status* status) {
+  CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+}
+
+void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) {
+  auto* status = TF_NewStatus();
+  TF_Tensor* t = TFE_TensorHandleResolve(handle, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+  tensorflow::Tensor dst;
+  TF_CHECK_OK(TF_TensorToTensor(t, &dst));
+  LOG(INFO) << dst.DebugString();
+
+  TF_DeleteTensor(t);
+  TF_DeleteStatus(status);
+}
+
+TFE_TensorHandle* TFE_RunConstOp(TFE_Context* ctx) {
+  // Intentionally logged via VLOG(1) below for ease of debugging.
+  VLOG(1) << "TFE_RunConstOp called";
+
+  auto* status = TF_NewStatus();
+  auto* op = TFE_NewOp(ctx, "Const", status);
+  CheckOk(status);
+  TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
+
+  auto* tensor =
+      TF_AllocateTensor(TF_FLOAT, /*shape.data()*/ nullptr, /*shape.size()*/ 0,
+                        TF_DataTypeSize(TF_FLOAT) * 1);
+  auto* ptr = reinterpret_cast<char*>(TF_TensorData(tensor));
+  *reinterpret_cast<float*>(ptr) = 17.0;
+
+  TFE_OpSetAttrTensor(op, "value", tensor, status);
+  CheckOk(status);
+  TF_DeleteTensor(tensor);
+  VLOG(1) << "New op created";
+
+  TFE_TensorHandle* retval;
+  int num_retvals = 1;
+  TFE_Execute(op, &retval, &num_retvals, status);
+  CheckOk(status);
+  CHECK_EQ(num_retvals, 1);
+  VLOG(1) << "Op executed";
+
+  TFE_DeleteOp(op);
+  TF_DeleteStatus(status);
+
+  return retval;
+}
diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h
index 522c91f..a3ca847 100644
--- a/tensorflow/c/c_api_experimental.h
+++ b/tensorflow/c/c_api_experimental.h
@@ -174,6 +174,15 @@
 TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor(
     TF_Session* session, int tensor_id, TF_Status* status);
 
+// Prints `handle` in a human readable format to standard output for debugging.
+TF_CAPI_EXPORT extern void TFE_TensorHandlePrintDebugString(
+    TFE_TensorHandle* handle);
+
+// Returns a const scalar tensor.
+// Caller owns both the input and the output tensor handles.
+// TODO: Remove this API with hard-coded tensor computation.
+TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_RunConstOp(TFE_Context* ctx);
+
 #ifdef __cplusplus
 } /* end extern "C" */
 #endif
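Note (illustrative, not part of this patch): a minimal sketch of how the two debug helpers declared above could be exercised. The context setup and teardown calls come from the existing eager C API in `tensorflow/c/eager/c_api.h`; exact signatures may differ slightly across versions.

```c++
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"

int main() {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_DeleteContextOptions(opts);
  if (TF_GetCode(status) != TF_OK) return 1;

  // Runs the hard-coded Const op and returns its scalar float output handle.
  TFE_TensorHandle* h = TFE_RunConstOp(ctx);

  // Logs the resolved tensor's DebugString() (a scalar containing 17).
  TFE_TensorHandlePrintDebugString(h);

  TFE_DeleteTensorHandle(h);
  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
  return 0;
}
```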
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 349d9bc..0bf3d95 100755
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -375,6 +375,17 @@
   return result;
 }
 
+int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
+  if (h == nullptr || h->handle == nullptr) {
+    status->status = tensorflow::errors::InvalidArgument(
+        "The passed in handle is a nullptr");
+    return -1;
+  }
+  tensorflow::int64 result;
+  status->status = h->handle->NumElements(&result);
+  return result;
+}
+
 int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
                             TF_Status* status) {
   if (h == nullptr || h->handle == nullptr) {
@@ -567,6 +578,13 @@
   op->operation.MutableAttrs()->Set(attr_name, attr_value);
 }
 
+void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
+                         TF_Status* status) {
+  tensorflow::Tensor t;
+  status->status = TF_TensorToTensor(tensor, &t);
+  if (status->status.ok()) op->operation.MutableAttrs()->Set(attr_name, t);
+}
+
 void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
                              const void* const* values, const size_t* lengths,
                              int num_values) {
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index 337447e..6323f8a 100755
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -163,6 +163,8 @@
 // This function will block till the operation that produces `h` has completed.
 TF_CAPI_EXPORT extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h,
                                                   TF_Status* status);
+TF_CAPI_EXPORT extern int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h,
+                                                          TF_Status* status);
 // This function will block till the operation that produces `h` has completed.
 TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h,
                                                   int dim_index,
@@ -311,6 +313,11 @@
                                                  const char* attr_name,
                                                  const TFE_Op* value);
 
+TF_CAPI_EXPORT extern void TFE_OpSetAttrTensor(TFE_Op* op,
+                                               const char* attr_name,
+                                               TF_Tensor* tensor,
+                                               TF_Status* status);
+
 TF_CAPI_EXPORT extern void TFE_OpSetAttrStringList(TFE_Op* op,
                                                    const char* attr_name,
                                                    const void* const* values,
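Note (illustrative, not part of this patch): a small sketch of the new `TFE_TensorHandleNumElements` query, using existing eager C API calls to construct a handle. The shape and values are arbitrary; `TFE_OpSetAttrTensor` (also added here) follows the same pattern as the Const op built in `TFE_RunConstOp` in c_api_experimental.cc.

```c++
#include "tensorflow/c/eager/c_api.h"
#include <cstdint>
#include <cstdio>

int main() {
  TF_Status* status = TF_NewStatus();

  // Build a 2x3 float tensor and wrap it in an eager tensor handle.
  const int64_t dims[] = {2, 3};
  TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, dims, 2, 6 * sizeof(float));
  float* data = static_cast<float*>(TF_TensorData(t));
  for (int i = 0; i < 6; ++i) data[i] = static_cast<float>(i);
  TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);

  // New in this change: total element count straight from the handle.
  int64_t n = TFE_TensorHandleNumElements(h, status);  // expect 6
  std::printf("num elements: %lld\n", static_cast<long long>(n));

  // TFE_OpSetAttrTensor is used the same way the Const op is built in
  // TFE_RunConstOp: attach a TF_Tensor as the op's "value" attribute.

  TFE_DeleteTensorHandle(h);
  TF_DeleteTensor(t);
  TF_DeleteStatus(status);
  return 0;
}
```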
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index ce038a4..41b5b8f 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -29,15 +29,8 @@
 namespace tensorflow {
 namespace eager {
 
-// Information about a tensor.
-struct TapeTensor {
-  int64 id;  // Expected to be unique in the lifetime of this process.
-  DataType dtype;
-  TensorShape shape;
-};
-
 // Represents an entry in the tape.
-template <typename BackwardFunction>
+template <typename BackwardFunction, typename TapeTensor>
 struct OpTapeEntry {
   string op_type;
   std::vector<TapeTensor> output_tensor_info;
@@ -57,8 +50,8 @@
 using TensorTape = gtl::FlatMap<int64, int64>;
 
 // Map from operation-id to tape entry.
-template <typename BackwardFunction>
-using OpTape = gtl::FlatMap<int64, OpTapeEntry<BackwardFunction>>;
+template <typename BackwardFunction, typename TapeTensor>
+using OpTape = gtl::FlatMap<int64, OpTapeEntry<BackwardFunction, TapeTensor>>;
 
 // Operations the tape needs to perform on tensors to do backpropagation. Named
 // "vspace" because a subset of these are related to a vector space, such as
@@ -79,7 +72,7 @@
 // TODO(apassos) provide concrete template instantiations for TFE_TensorHandle
 // specialization, which is blocked by quite a few things needing to loop back
 // into python now.
-template <typename Gradient, typename BackwardFunction>
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
 class VSpace {
  public:
   virtual ~VSpace() {}
@@ -93,10 +86,10 @@
       gtl::ArraySlice<Gradient*> gradient_tensors) const = 0;
 
   // Returns a tensor of the right shape and dtype filled with zeros.
-  virtual Gradient* Zeros(TensorShape shape, DataType dtype) const = 0;
+  virtual Gradient* Zeros(const TapeTensor& tensor) const = 0;
 
   // Returns a Tensor which is filled with ones and like the input.
-  virtual Gradient* Ones(TensorShape shape, DataType dtype) const = 0;
+  virtual Gradient* Ones(const TapeTensor& tensor) const = 0;
 
   // Calls the passed-in backward function.
   virtual Status CallBackwardFunction(
@@ -114,7 +107,7 @@
 
 // Traces the execution of operations, doing eager garbage collection, and
 // exporting a full trace so other code can do backpropagation. Not thread-safe.
-template <typename Gradient, typename BackwardFunction>
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
 class GradientTape {
  public:
   // If `persistent` is true, GradientTape will not eagerly delete backward
@@ -134,7 +127,7 @@
   void Watch(int64 tensor_id);
 
   void RecordOperation(
-      const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
+      const string& op_type, std::vector<TapeTensor>& output_tensors,
       gtl::ArraySlice<int64> input_tensor_id,
       gtl::ArraySlice<tensorflow::DataType> input_dtypes,
       BackwardFunction* backward_function,
@@ -146,17 +139,18 @@
   // once) and produces the gradient of the target tensors with respect to the
   // source tensors. The output gradients are used if not empty and not
   // null. The result is populated with one tensor per target element.
-  Status ComputeGradient(const VSpace<Gradient, BackwardFunction>& vspace,
-                         gtl::ArraySlice<int64> target_tensor_ids,
-                         gtl::ArraySlice<int64> source_tensor_id,
-                         gtl::ArraySlice<Gradient*> output_gradients,
-                         std::vector<Gradient*>* result);
+  Status ComputeGradient(
+      const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
+      gtl::ArraySlice<int64> target_tensor_ids,
+      gtl::ArraySlice<int64> source_tensor_id,
+      gtl::ArraySlice<Gradient*> output_gradients,
+      std::vector<Gradient*>* result);
 
   bool IsPersistent() const { return persistent_; }
 
  private:
   TensorTape tensor_tape_;
-  OpTape<BackwardFunction> op_tape_;
+  OpTape<BackwardFunction, TapeTensor> op_tape_;
   int64 next_op_id_{0};
 
   // Map from tensor id to number of remaining usages (i.e. how many entries in
@@ -186,8 +180,8 @@
   }
 }
 
-template <typename Gradient, typename BackwardFunction>
-bool GradientTape<Gradient, BackwardFunction>::ShouldRecord(
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+bool GradientTape<Gradient, BackwardFunction, TapeTensor>::ShouldRecord(
     gtl::ArraySlice<int64> tensor_ids,
     gtl::ArraySlice<tensorflow::DataType> dtypes) {
   CHECK_EQ(tensor_ids.size(), dtypes.size());
@@ -201,14 +195,15 @@
   return false;
 }
 
-template <typename Gradient, typename BackwardFunction>
-void GradientTape<Gradient, BackwardFunction>::Watch(int64 tensor_id) {
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+void GradientTape<Gradient, BackwardFunction, TapeTensor>::Watch(
+    int64 tensor_id) {
   tensor_tape_.emplace(tensor_id, -1);
 }
 
-template <typename Gradient, typename BackwardFunction>
-void GradientTape<Gradient, BackwardFunction>::RecordOperation(
-    const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+void GradientTape<Gradient, BackwardFunction, TapeTensor>::RecordOperation(
+    const string& op_type, std::vector<TapeTensor>& output_tensors,
     gtl::ArraySlice<int64> input_tensor_id,
     gtl::ArraySlice<tensorflow::DataType> input_dtypes,
     BackwardFunction* backward_function,
@@ -229,16 +224,18 @@
   for (const TapeTensor& o : output_tensors) {
     // Note: the tensor can have already been watched and hence be in the tape,
     // so we cannot check that we're inserting it here.
-    tensor_tape_[o.id] = op_id;
-    tensor_usage_[o.id] = 1;
+    tensor_tape_[o.GetID()] = op_id;
+    tensor_usage_[o.GetID()] = 1;
     tensors.push_back(o);
   }
-  op_tape_[op_id] = OpTapeEntry<BackwardFunction>{
-      op_type, tensors, ids, backward_function, backward_function_deleter};
+  op_tape_[op_id] = OpTapeEntry<BackwardFunction, TapeTensor>{
+      op_type, std::move(tensors), ids, backward_function,
+      backward_function_deleter};
 }
 
-template <typename Gradient, typename BackwardFunction>
-void GradientTape<Gradient, BackwardFunction>::DeleteTrace(int64 tensor_id) {
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+void GradientTape<Gradient, BackwardFunction, TapeTensor>::DeleteTrace(
+    int64 tensor_id) {
   auto it = tensor_usage_.find(tensor_id);
   if (it == tensor_usage_.end()) {
     return;
@@ -261,7 +258,7 @@
   auto op_it = op_tape_.find(op_id);
   CHECK(op_it != op_tape_.end());
   for (const auto& output : op_it->second.output_tensor_info) {
-    if (tensor_usage_.find(output.id) != tensor_usage_.end()) {
+    if (tensor_usage_.find(output.GetID()) != tensor_usage_.end()) {
       // Found a usage for an output, so cannot delete the op.
       return;
     }
@@ -304,9 +301,9 @@
 
 namespace {
 
-template <typename BackwardFunction>
+template <typename BackwardFunction, typename TapeTensor>
 struct BackpropInitialState {
-  OpTape<BackwardFunction> op_tape;
+  OpTape<BackwardFunction, TapeTensor> op_tape;
 
   // Map from tensor ID to how many references still exist for this tensor in
   // the tape.
@@ -322,17 +319,17 @@
 // If `persistent_tape` is false, op_tape is cleared and backwards functions
 // not needed for gradient computation are deleted. Backwards functions that
 // are needed, are copied and returned in BackpropInitialState.
-template <typename BackwardFunction>
-BackpropInitialState<BackwardFunction> PrepareBackprop(
+template <typename BackwardFunction, typename TapeTensor>
+BackpropInitialState<BackwardFunction, TapeTensor> PrepareBackprop(
     gtl::ArraySlice<int64> target, const TensorTape& tensor_tape,
-    OpTape<BackwardFunction>* op_tape, const gtl::FlatSet<int64>& sources_set,
-    bool persistent_tape) {
+    OpTape<BackwardFunction, TapeTensor>* op_tape,
+    const gtl::FlatSet<int64>& sources_set, bool persistent_tape) {
   std::vector<int64> tensor_stack;
   tensor_stack.reserve(target.size());
   for (auto t : target) {
     tensor_stack.push_back(t);
   }
-  BackpropInitialState<BackwardFunction> result;
+  BackpropInitialState<BackwardFunction, TapeTensor> result;
   while (!tensor_stack.empty()) {
     int64 tensor_id = tensor_stack.back();
     tensor_stack.pop_back();
@@ -383,9 +380,9 @@
   return result;
 }
 
-template <typename BackwardFunction>
+template <typename BackwardFunction, typename TapeTensor>
 std::vector<int64> InitialStack(
-    const OpTape<BackwardFunction>& op_tape,
+    const OpTape<BackwardFunction, TapeTensor>& op_tape,
     const gtl::FlatMap<int64, int64>& op_missing_tensor) {
   std::vector<int64> result;
   for (auto& op_entry : op_tape) {
@@ -396,13 +393,13 @@
   return result;
 }
 
-template <typename Gradient, typename BackwardFunction>
-Status InitialGradients(const VSpace<Gradient, BackwardFunction>& vspace,
-                        gtl::ArraySlice<int64> target_tensor_ids,
-                        gtl::ArraySlice<Gradient*> output_gradients,
-                        const TensorTape& tensor_tape,
-                        const OpTape<BackwardFunction>& op_tape,
-                        gtl::FlatMap<int64, std::vector<Gradient*>>* result) {
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+Status InitialGradients(
+    const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
+    gtl::ArraySlice<int64> target_tensor_ids,
+    gtl::ArraySlice<Gradient*> output_gradients, const TensorTape& tensor_tape,
+    const OpTape<BackwardFunction, TapeTensor>& op_tape,
+    gtl::FlatMap<int64, std::vector<Gradient*>>* result) {
   for (int i = 0; i < target_tensor_ids.size(); ++i) {
     const int64 id = target_tensor_ids[i];
     if (output_gradients.empty() || output_gradients[i] == nullptr) {
@@ -416,11 +413,10 @@
         }
         bool found = false;
         for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
-          if (op_it->second.output_tensor_info[j].id == id) {
+          if (op_it->second.output_tensor_info[j].GetID() == id) {
             found = true;
             (*result)[id].push_back(
-                vspace.Ones(op_it->second.output_tensor_info[j].shape,
-                            op_it->second.output_tensor_info[j].dtype));
+                vspace.Ones(op_it->second.output_tensor_info[j]));
             break;
           }
         }
@@ -440,6 +436,18 @@
   return Status::OK();
 }
 
+// TODO(agarwal): use an automatic mechanism for handling None arguments to
+// gradient functions.
+//
+// Some gradient functions can accept None arguments for gradients. The
+// following maps the operation name to the indices at which the corresponding
+// gradient function can accept None values. e.g. FusedBatchNorm outputs 5
+// values and hence receives 5 gradient values during backprop. However the
+// gradient function uses only the first of those values and ignores the rest.
+// The entry, "FusedBatchNorm": [1, 2, 3, 4], indicates that only the gradient
+// corresponding to index 0 is used, and the gradient values at indices 1-4 are
+// ignored (and hence can be None). The backprop algorithm can then leverage
+// this by not constructing zeros to pass for those indices.
 gtl::FlatMap<string, gtl::FlatSet<int>>* FunctionsAcceptingNoneForIndicesMap() {
   static auto* const m = new gtl::FlatMap<string, gtl::FlatSet<int>>({
       {"SoftmaxCrossEntropyWithLogits", {1}},
@@ -457,16 +465,16 @@
 constexpr int kMinAggregateCount = 4;
 constexpr int kMinAggregateBytes = 128 * 1024 * 1024;
 
-template <typename Gradient, typename BackwardFunction>
-Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
-    const VSpace<Gradient, BackwardFunction>& vspace,
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
+    const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
     gtl::ArraySlice<int64> target_tensor_ids,
     gtl::ArraySlice<int64> source_tensor_ids,
     gtl::ArraySlice<Gradient*> output_gradients,
     std::vector<Gradient*>* result) {
   gtl::FlatSet<int64> sources_set(source_tensor_ids.begin(),
                                   source_tensor_ids.end());
-  BackpropInitialState<BackwardFunction> state = PrepareBackprop(
+  BackpropInitialState<BackwardFunction, TapeTensor> state = PrepareBackprop(
       target_tensor_ids, tensor_tape_, &op_tape_, sources_set, persistent_);
   std::vector<int64> op_stack =
       InitialStack(state.op_tape, state.op_missing_tensor);
@@ -510,7 +518,7 @@
     out_gradients.reserve(trace.output_tensor_info.size());
     bool any_gradient_nonzero = false;
     for (int i = 0; i < trace.output_tensor_info.size(); ++i) {
-      const int64 id = trace.output_tensor_info[i].id;
+      const int64 id = trace.output_tensor_info[i].GetID();
       auto grad_it = gradients.find(id);
       if (grad_it == gradients.end()) {
         auto func_name_it =
@@ -519,9 +527,7 @@
             func_name_it->second.find(i) != func_name_it->second.end()) {
           out_gradients.push_back(nullptr);
         } else {
-          out_gradients.push_back(
-              vspace.Zeros(trace.output_tensor_info[i].shape,
-                           trace.output_tensor_info[i].dtype));
+          out_gradients.push_back(vspace.Zeros(trace.output_tensor_info[i]));
         }
       } else {
         any_gradient_nonzero = true;
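Note (illustrative, not part of this patch): with this change, TapeTensor becomes a template parameter instead of a fixed struct, and `VSpace::Zeros`/`Ones` receive the tape tensor itself. A hypothetical client-side tape tensor type satisfying the interface this header relies on might look like the sketch below; the type name and extra fields are assumptions, not requirements of tape.h.

```c++
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"

// Hypothetical client-defined tape tensor. The only member tape.h itself
// calls is GetID(); dtype/shape are kept here only because this client's
// VSpace wants them when constructing Zeros()/Ones().
struct MyTapeTensor {
  tensorflow::int64 id;  // expected to be unique for the process lifetime
  tensorflow::DataType dtype;
  tensorflow::TensorShape shape;

  tensorflow::int64 GetID() const { return id; }
};

// A matching VSpace specialization would then implement, e.g.:
//   Gradient* Zeros(const MyTapeTensor& t) const override;
//   Gradient* Ones(const MyTapeTensor& t) const override;
// and the tape would be instantiated as
//   tensorflow::eager::GradientTape<Gradient, BackwardFunction, MyTapeTensor>.
```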
diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc
index 8486b58..247236b 100644
--- a/tensorflow/c/python_api.cc
+++ b/tensorflow/c/python_api.cc
@@ -110,7 +110,7 @@
   session->extend_before_run = false;
 }
 
-std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) {
+std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output) {
   Node* node = &output.oper->node;
   CppShapeInferenceResult::HandleData handle_data;
   handle_data.set_is_set(true);
@@ -135,9 +135,8 @@
   return result;
 }
 
-void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output,
-                                   const void* proto, size_t proto_len,
-                                   TF_Status* status) {
+void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
+                           size_t proto_len, TF_Status* status) {
   tensorflow::CppShapeInferenceResult::HandleData handle_data;
   if (!handle_data.ParseFromArray(proto, proto_len)) {
     status->status = tensorflow::errors::InvalidArgument(
diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h
index 4bcb5bd..5cce840 100644
--- a/tensorflow/c/python_api.h
+++ b/tensorflow/c/python_api.h
@@ -54,16 +54,17 @@
 void ExtendSession(TF_Session* session, TF_Status* status);
 
 // Returns the serialized CppShapeInferenceResult::HandleData proto for
-// `output` if its a resource tensor, or otherwise returns the empty string.
-std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output);
+// `output` if its a resource or variant tensor, or otherwise returns the empty
+// string.
+std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output);
 
 // Sets `output` based on `proto`, which should be a serialized
-// CppShapeInferenceResult::HandleData proto.
+// CppShapeInferenceResult::HandleData proto. `output` should be a resource
+// or variant tensor.
 // NOTE(skyewm): `proto` is passed a void*/size_t pair instead of a std::string
 // because I couldn't get SWIG to work otherwise.
-void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output,
-                                   const void* proto, size_t proto_len,
-                                   TF_Status* status);
+void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
+                           size_t proto_len, TF_Status* status);
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_C_PYTHON_API_H_
diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD
index f56521da..b587e63 100644
--- a/tensorflow/cc/BUILD
+++ b/tensorflow/cc/BUILD
@@ -10,11 +10,12 @@
 
 load(
     "//tensorflow:tensorflow.bzl",
-    "tf_cc_test",
+    "cc_library_with_android_deps",
     "tf_cc_binary",
+    "tf_cc_test",
     "tf_copts",
     "tf_gen_op_wrappers_cc",
-    "cc_library_with_android_deps",
+    "transitive_hdrs",
 )
 
 cc_library(
@@ -716,3 +717,26 @@
         "//tensorflow/core:testlib",
     ],
 )
+
+transitive_hdrs(
+    name = "headers",
+    visibility = ["//tensorflow:__subpackages__"],
+    deps = [
+        ":cc_ops",
+        ":client_session",
+        ":coordinator",
+        ":gradient_checker",
+        ":gradients",
+        ":ops",
+        ":queue_runner",
+        ":remote_fused_graph_ops",
+        ":scope",
+        "//tensorflow/cc/profiler",
+        "//tensorflow/cc/saved_model:constants",
+        "//tensorflow/cc/saved_model:loader",
+        "//tensorflow/cc/saved_model:reader",
+        "//tensorflow/cc/saved_model:signature_constants",
+        "//tensorflow/cc/saved_model:tag_constants",
+        "//tensorflow/cc/tools:freeze_saved_model",
+    ],
+)
diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD
index 7a0932d..10fa33a 100644
--- a/tensorflow/compiler/aot/tests/BUILD
+++ b/tensorflow/compiler/aot/tests/BUILD
@@ -25,6 +25,7 @@
         ":test_graph_tfmatmul_test",
         ":test_graph_tfmatmulandadd_test",
         ":test_graph_tfsplits_test",
+        ":test_graph_tftop_k_test",
         ":tfcompile_test",
     ],
 )
@@ -42,6 +43,7 @@
         "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:math_ops",
+        "//tensorflow/python:nn_ops",
         "//tensorflow/python:platform",
         "//tensorflow/python:session",
         "//tensorflow/python:training",
@@ -66,6 +68,7 @@
         "test_graph_tfmatmul.pb",
         "test_graph_tfmatmulandadd.pb",
         "test_graph_tfsplits.pb",
+        "test_graph_tftop_k.pb",
     ],
     # Set CUDA_VISIBLE_DEVICES='' to prevent the code we launch from using any
     # GPUs which might be present.  This is important because builds may run
@@ -208,6 +211,17 @@
     ],
 )
 
+tf_library(
+    name = "test_graph_tftop_k",
+    testonly = 1,
+    config = "test_graph_tftop_k.config.pbtxt",
+    cpp_class = "TopKComp",
+    graph = "test_graph_tftop_k.pb",
+    tags = [
+        "manual",
+    ],
+)
+
 tf_cc_test(
     name = "tfcompile_test",
     srcs = ["tfcompile_test.cc"],
@@ -226,6 +240,7 @@
         ":test_graph_tfmatmulandadd",
         ":test_graph_tfmatmulandadd_with_profiling",
         ":test_graph_tfsplits",
+        ":test_graph_tftop_k",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:test",
         "//tensorflow/compiler/xla:xla_data_proto",
diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py
index 9ec7df1..de135d7 100644
--- a/tensorflow/compiler/aot/tests/make_test_graphs.py
+++ b/tensorflow/compiler/aot/tests/make_test_graphs.py
@@ -31,6 +31,7 @@
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import app
 from tensorflow.python.training import saver as saver_lib
@@ -142,6 +143,12 @@
   array_ops.identity(y, name='result')
 
 
+def tftop_k(_):
+  x = array_ops.placeholder(dtypes.int32, shape=[5], name='x')
+  output = nn_ops.top_k(x, 2, name='values')
+  array_ops.identity(output[1], name='indices')
+
+
 def write_graph(build_graph, out_dir):
   """Build a graph using build_graph and write it out."""
   g = ops.Graph()
@@ -163,6 +170,7 @@
   write_graph(tfmatmul, FLAGS.out_dir)
   write_graph(tfmatmulandadd, FLAGS.out_dir)
   write_graph(tfsplits, FLAGS.out_dir)
+  write_graph(tftop_k, FLAGS.out_dir)
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt
new file mode 100644
index 0000000..6b4ac2d
--- /dev/null
+++ b/tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt
@@ -0,0 +1,13 @@
+# Text form of tensorflow.tf2xla.Config proto.
+feed {
+  id { node_name: "x" }
+  shape {
+    dim { size: 5 }
+  }
+}
+fetch {
+  id { node_name: "values" }
+}
+fetch {
+  id { node_name: "indices" }
+}
diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc
index 7ac90fb..f10852c 100644
--- a/tensorflow/compiler/aot/tests/tfcompile_test.cc
+++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc
@@ -29,6 +29,7 @@
 #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h"
 #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling.h"
 #include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h"
+#include "tensorflow/compiler/aot/tests/test_graph_tftop_k.h"
 #include "tensorflow/compiler/xla/service/hlo_profile_printer.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
@@ -448,6 +449,30 @@
   EXPECT_NEAR(expected[3], fn.result0(1, 1), 1e4);
 }
 
+TEST(TFCompileTest, TopK) {
+  Eigen::ThreadPool tp(1);
+  Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
+
+  TopKComp fn;
+
+  fn.set_thread_pool(&device);
+  // x = [4, 1, 4, 4, 3]
+  fn.arg0(0) = 4;
+  fn.arg0(1) = 1;
+  fn.arg0(2) = 4;
+  fn.arg0(3) = 4;
+  fn.arg0(4) = 3;
+
+  EXPECT_TRUE(fn.Run());
+  EXPECT_EQ(fn.error_msg(), "");
+  const int32 expected_values[] = {4, 4};
+  const int32 expected_indices[] = {0, 2};
+  EXPECT_EQ(expected_values[0], fn.result0(0));
+  EXPECT_EQ(expected_values[1], fn.result0(1));
+  EXPECT_EQ(expected_indices[0], fn.result1(0));
+  EXPECT_EQ(expected_indices[1], fn.result1(1));
+}
+
 TEST(TFCompileTest, AssertEqAndReturnDiff) {
   // Assert is converted into a no-op in XLA, so there is no failure even if the
   // two args are different.
diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl
index 792b7fe..859c84b 100644
--- a/tensorflow/compiler/aot/tfcompile.bzl
+++ b/tensorflow/compiler/aot/tfcompile.bzl
@@ -273,6 +273,7 @@
             "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",
             "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d",
             "//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
+            "//tensorflow/compiler/xla/service/cpu:runtime_key_value_sort",
             "//tensorflow/compiler/xla/service/cpu:runtime_matmul",
             "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d",
             "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 7d5db71..4e18472 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -26,6 +26,7 @@
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
 load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
 
 # Target that bundles up the XLA CPU and GPU JIT devices.
 cc_library(
@@ -50,7 +51,7 @@
     visibility = ["//visibility:public"],
     deps = [
         ":jit_compilation_passes",
-        "//tensorflow/compiler/jit/kernels:xla_launch_op",
+        "//tensorflow/compiler/jit/kernels:xla_ops",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
         "//tensorflow/compiler/xla/service:cpu_plugin",
     ],
@@ -62,7 +63,7 @@
     visibility = ["//visibility:public"],
     deps = if_cuda([
         ":jit_compilation_passes",
-        "//tensorflow/compiler/jit/kernels:xla_launch_op",
+        "//tensorflow/compiler/jit/kernels:xla_ops",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
         "//tensorflow/compiler/xla/service:gpu_plugin",
     ]),
@@ -76,7 +77,7 @@
     deps = [
         ":jit_compilation_passes",
         ":xla_device",
-        "//tensorflow/compiler/jit/kernels:xla_launch_op",
+        "//tensorflow/compiler/jit/kernels:xla_ops",
         "//tensorflow/compiler/jit/legacy_flags:xla_device_flags",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
@@ -94,7 +95,7 @@
     deps = [
         ":jit_compilation_passes",
         ":xla_device",
-        "//tensorflow/compiler/jit/kernels:xla_launch_op",
+        "//tensorflow/compiler/jit/kernels:xla_ops",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
         "//tensorflow/compiler/xla/service:gpu_plugin",  # buildcleaner: keep
@@ -111,7 +112,7 @@
     deps = [
         ":jit_compilation_passes",
         ":xla_device",
-        "//tensorflow/compiler/jit/kernels:xla_launch_op",
+        "//tensorflow/compiler/jit/kernels:xla_ops",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
         "//tensorflow/compiler/xla/service:interpreter_plugin",  # buildcleaner: keep
@@ -280,7 +281,7 @@
     deps = [
         ":common",
         ":compilation_passes",
-        "//tensorflow/compiler/jit/kernels:xla_launch_op",
+        "//tensorflow/compiler/jit/kernels:xla_ops",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
@@ -341,7 +342,7 @@
         "//tensorflow/cc:ops",
         "//tensorflow/cc:resource_variable_ops",
         "//tensorflow/cc:sendrecv_ops",
-        "//tensorflow/compiler/jit/kernels:xla_launch_op",
+        "//tensorflow/compiler/jit/kernels:xla_ops",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
         "//tensorflow/core:core_cpu",
@@ -359,18 +360,20 @@
 cc_library(
     name = "compilation_passes",
     srcs = [
-        "build_xla_launch_ops_pass.cc",
+        "build_xla_ops_pass.cc",
         "deadness_analysis.cc",
         "deadness_analysis_internal.h",
         "encapsulate_subgraphs_pass.cc",
+        "encapsulate_xla_computations_pass.cc",
         "mark_for_compilation_pass.cc",
         "mark_for_compilation_pass_test_helper.cc",
         "partially_decluster_pass.cc",
     ],
     hdrs = [
-        "build_xla_launch_ops_pass.h",
+        "build_xla_ops_pass.h",
         "deadness_analysis.h",
         "encapsulate_subgraphs_pass.h",
+        "encapsulate_xla_computations_pass.h",
         "mark_for_compilation_pass.h",
         "mark_for_compilation_pass_test_helper.h",
         "partially_decluster_pass.h",
@@ -397,6 +400,7 @@
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/kernels:bounds_check",
         "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -456,7 +460,7 @@
         "//tensorflow/cc:function_ops",
         "//tensorflow/cc:ops",
         "//tensorflow/cc:sendrecv_ops",
-        "//tensorflow/compiler/jit/kernels:xla_launch_op",
+        "//tensorflow/compiler/jit/kernels:xla_ops",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
         "//tensorflow/core:core_cpu",
@@ -475,6 +479,7 @@
     size = "small",
     srcs = [
         "encapsulate_subgraphs_pass_test.cc",
+        "encapsulate_xla_computations_pass_test.cc",
         "mark_for_compilation_pass_test.cc",
         "partially_decluster_pass_test.cc",
     ],
@@ -489,8 +494,10 @@
         "//tensorflow/cc:ops",
         "//tensorflow/cc:resource_variable_ops",
         "//tensorflow/cc:sendrecv_ops",
-        "//tensorflow/compiler/jit/kernels:xla_launch_op",
+        "//tensorflow/compiler/jit/kernels:xla_ops",
+        "//tensorflow/compiler/tf2xla:test_util",
         "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
@@ -518,7 +525,7 @@
         "//tensorflow/cc:cc_ops_internal",
         "//tensorflow/cc:function_ops",
         "//tensorflow/cc:ops",
-        "//tensorflow/compiler/jit/kernels:xla_launch_op",
+        "//tensorflow/compiler/jit/kernels:xla_ops",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
         "//tensorflow/core:core_cpu",
@@ -593,6 +600,44 @@
     ],
 )
 
+cc_library(
+    name = "node_matchers",
+    testonly = True,
+    srcs = ["node_matchers.cc"],
+    hdrs = ["node_matchers.h"],
+    deps = [
+        "//tensorflow/cc:ops",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/core:graph",
+        "//tensorflow/core:protos_all_cc",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:optional",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+tf_cc_test(
+    name = "node_matchers_test",
+    srcs = ["node_matchers_test.cc"],
+    deps = [
+        ":node_matchers",
+        "//tensorflow/cc:cc_ops",
+        "//tensorflow/cc:ops",
+        "//tensorflow/core:ops",
+        "//tensorflow/core:test_main",
+    ],
+)
+
+tf_custom_op_py_library(
+    name = "xla_ops_py",
+    kernels = ["//tensorflow/compiler/jit/ops:xla_ops"],
+    visibility = [
+        ":friends",
+    ],
+    deps = ["//tensorflow/compiler/jit/ops:xla_ops_wrapper_py"],
+)
+
 # This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library.
 cc_header_only_library(
     name = "xla_jit_headers_lib",
diff --git a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc
deleted file mode 100644
index b17ff58..0000000
--- a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc
+++ /dev/null
@@ -1,142 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h"
-#include "tensorflow/compiler/jit/defs.h"
-#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
-#include "tensorflow/compiler/tf2xla/dump_graph.h"
-#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/core/common_runtime/function.h"
-#include "tensorflow/core/common_runtime/optimization_registry.h"
-#include "tensorflow/core/framework/graph_def_util.h"
-#include "tensorflow/core/framework/node_def_builder.h"
-#include "tensorflow/core/framework/node_def_util.h"
-#include "tensorflow/core/graph/algorithm.h"
-#include "tensorflow/core/graph/graph.h"
-#include "tensorflow/core/graph/graph_constructor.h"
-#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/hash/hash.h"
-#include "tensorflow/core/public/version.h"
-
-namespace tensorflow {
-
-static Status BuildLaunchNode(
-    const string& nodename, const string& function_name,
-    const AttrValueMap& function_attr, const string& device_name,
-    const DataTypeVector& constant_dtypes, int num_resources,
-    const DataTypeVector& arg_dtypes, const DataTypeVector& result_dtypes,
-    Graph* graph, Node** node) {
-  NodeDef def;
-  def.set_name(graph->NewName(nodename));
-  def.set_op("XlaLaunch");
-  def.set_device(device_name);
-  AddNodeAttr("Tconstants", constant_dtypes, &def);
-  AddNodeAttr("Targs", arg_dtypes, &def);
-  AddNodeAttr("Nresources", num_resources, &def);
-  AddNodeAttr("Tresults", result_dtypes, &def);
-  NameAttrList function;
-  function.set_name(function_name);
-  *function.mutable_attr() = function_attr;
-  AddNodeAttr("function", function, &def);
-
-  Status status;
-  *node = graph->AddNode(def, &status);
-  return status;
-}
-
-static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) {
-  VLOG(2) << "Replacing " << node->name() << " with XlaLaunch";
-
-  int num_constant_args, num_resource_args;
-  TF_RETURN_IF_ERROR(
-      GetNodeAttr(node->attrs(), kXlaNumConstantArgsAttr, &num_constant_args));
-  TF_RETURN_IF_ERROR(
-      GetNodeAttr(node->attrs(), kXlaNumResourceArgsAttr, &num_resource_args));
-
-  if (num_constant_args < 0 || num_resource_args < 0 ||
-      num_constant_args + num_resource_args > node->num_inputs()) {
-    return errors::InvalidArgument(
-        "Invalid number of constant/resource arguments to XLA kernel.");
-  }
-  const int num_nonconst_args =
-      node->num_inputs() - num_constant_args - num_resource_args;
-
-  DataTypeVector const_dtypes(node->input_types().begin(),
-                              node->input_types().begin() + num_constant_args);
-  DataTypeVector arg_dtypes(
-      node->input_types().begin() + num_constant_args,
-      node->input_types().begin() + num_constant_args + num_nonconst_args);
-
-  // Build a XlaLaunch operator to execute the function body.
-  Node* launch_node;
-  TF_RETURN_IF_ERROR(BuildLaunchNode(
-      graph->NewName(node->name()), node->type_string(), node->def().attr(),
-      node->requested_device(), const_dtypes, num_resource_args, arg_dtypes,
-      node->output_types(), graph, &launch_node));
-  launch_node->set_assigned_device_name(node->assigned_device_name());
-
-  // Copy incoming edges to the launch node.
-  for (const Edge* edge : node->in_edges()) {
-    if (edge->IsControlEdge()) {
-      graph->AddControlEdge(edge->src(), launch_node);
-    } else {
-      graph->AddEdge(edge->src(), edge->src_output(), launch_node,
-                     edge->dst_input());
-    }
-  }
-
-  // Copy outgoing edges to the launch node.
-  std::vector<const Edge*> out_edges(node->out_edges().begin(),
-                                     node->out_edges().end());
-  for (const Edge* edge : out_edges) {
-    Node* dst = edge->dst();
-    int src_output = edge->src_output();
-    int dst_input = edge->dst_input();
-    graph->RemoveEdge(edge);
-
-    if (edge->IsControlEdge()) {
-      graph->AddControlEdge(launch_node, dst);
-    } else {
-      graph->AddEdge(launch_node, src_output, dst, dst_input);
-    }
-  }
-  graph->RemoveNode(node);
-
-  return Status::OK();
-}
-
-Status BuildXlaLaunchOpsPass::Run(const GraphOptimizationPassOptions& options) {
-  Graph* graph = options.graph->get();
-
-  for (Node* n : graph->op_nodes()) {
-    // In all cases, only try to compile computational nodes.
-    if (n->IsSend() || n->IsRecv() || n->IsControlFlow()) {
-      continue;
-    }
-
-    // Only compile nodes that are marked for compilation by the
-    // compilation-marking pass (via 'attr_name').
-    if (IsXlaCompiledKernel(*n)) {
-      TF_RETURN_IF_ERROR(ReplaceNodeWithXlaLaunch(graph, n));
-    }
-  }
-
-  if (VLOG_IS_ON(1)) {
-    dump_graph::DumpGraphToFile("build_xla_launch_ops", *graph,
-                                options.flib_def);
-  }
-  return Status::OK();
-}
-}  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc
new file mode 100644
index 0000000..a6086f3
--- /dev/null
+++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc
@@ -0,0 +1,187 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/build_xla_ops_pass.h"
+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
+#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/framework/graph_def_util.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tensorflow {
+
+static Status BuildXlaCompileNode(
+    const string& nodename, const string& function_name,
+    const AttrValueMap& function_attr, const string& device_name,
+    const DataTypeVector& constant_dtypes, int num_resources,
+    const DataTypeVector& arg_dtypes, Graph* graph, Node** node) {
+  NodeDef def;
+  def.set_name(graph->NewName(nodename));
+  def.set_op("_XlaCompile");
+  def.set_device(device_name);
+  AddNodeAttr("Tconstants", constant_dtypes, &def);
+  AddNodeAttr("Targs", arg_dtypes, &def);
+  AddNodeAttr("Nresources", num_resources, &def);
+  NameAttrList function;
+  function.set_name(function_name);
+  *function.mutable_attr() = function_attr;
+  AddNodeAttr("function", function, &def);
+
+  Status status;
+  *node = graph->AddNode(def, &status);
+  return status;
+}
+
+static Status BuildXlaRunNode(const string& nodename, const string& device_name,
+                              const DataTypeVector& constant_dtypes,
+                              const DataTypeVector& arg_dtypes,
+                              const DataTypeVector& result_dtypes, Graph* graph,
+                              Node** node) {
+  NodeDef def;
+  def.set_name(graph->NewName(nodename));
+  def.set_op("_XlaRun");
+  def.set_device(device_name);
+  AddNodeAttr("Tconstants", constant_dtypes, &def);
+  AddNodeAttr("Targs", arg_dtypes, &def);
+  AddNodeAttr("Tresults", result_dtypes, &def);
+
+  Status status;
+  *node = graph->AddNode(def, &status);
+  return status;
+}
+
+static Status GetXlaAttrs(Node* node, int* num_constant_args,
+                          int* num_resource_args, DataTypeVector* const_dtypes,
+                          DataTypeVector* arg_dtypes) {
+  TF_RETURN_IF_ERROR(
+      GetNodeAttr(node->attrs(), kXlaNumConstantArgsAttr, num_constant_args));
+  TF_RETURN_IF_ERROR(
+      GetNodeAttr(node->attrs(), kXlaNumResourceArgsAttr, num_resource_args));
+
+  if (*num_constant_args < 0 || *num_resource_args < 0 ||
+      *num_constant_args + *num_resource_args > node->num_inputs()) {
+    return errors::InvalidArgument(
+        "Invalid number of constant/resource arguments to XLA kernel.");
+  }
+
+  const int num_nonconst_args =
+      node->num_inputs() - *num_constant_args - *num_resource_args;
+
+  const DataTypeVector& input_types = node->input_types();
+  std::copy(input_types.begin(), input_types.begin() + *num_constant_args,
+            std::back_inserter(*const_dtypes));
+  std::copy(input_types.begin() + *num_constant_args,
+            input_types.begin() + *num_constant_args + num_nonconst_args,
+            std::back_inserter(*arg_dtypes));
+  return Status::OK();
+}
+
+static void CopyIncomingEdges(Graph* g, Node* old_node, Node* new_node) {
+  for (const Edge* edge : old_node->in_edges()) {
+    if (edge->IsControlEdge()) {
+      g->AddControlEdge(edge->src(), new_node);
+    } else {
+      g->AddEdge(edge->src(), edge->src_output(), new_node, edge->dst_input());
+    }
+  }
+}
+
+static void MoveOutgoingEdges(Graph* g, Node* old_node, Node* new_node) {
+  std::vector<const Edge*> out_edges(old_node->out_edges().begin(),
+                                     old_node->out_edges().end());
+  for (const Edge* edge : out_edges) {
+    Node* dst = edge->dst();
+    int src_output = edge->src_output();
+    int dst_input = edge->dst_input();
+    g->RemoveEdge(edge);
+
+    if (edge->IsControlEdge()) {
+      g->AddControlEdge(new_node, dst);
+    } else {
+      g->AddEdge(new_node, src_output, dst, dst_input);
+    }
+  }
+}
+
+static Status ReplaceNodeWithXlaCompileAndRun(Graph* g, Node* n) {
+  int num_constant_args, num_resource_args;
+  DataTypeVector const_dtypes;
+  DataTypeVector arg_dtypes;
+
+  TF_RETURN_IF_ERROR(GetXlaAttrs(n, &num_constant_args, &num_resource_args,
+                                 &const_dtypes, &arg_dtypes));
+
+  Node *compile_node, *run_node;
+
+  TF_RETURN_IF_ERROR(BuildXlaCompileNode(
+      n->name(), n->type_string(), n->def().attr(), n->requested_device(),
+      const_dtypes, num_resource_args, arg_dtypes, g, &compile_node));
+
+  DataTypeVector arg_dtypes_with_resources = arg_dtypes;
+  for (int i = 0; i < num_resource_args; i++) {
+    arg_dtypes_with_resources.push_back(DT_RESOURCE);
+  }
+
+  TF_RETURN_IF_ERROR(BuildXlaRunNode(n->name(), n->requested_device(),
+                                     const_dtypes, arg_dtypes_with_resources,
+                                     n->output_types(), g, &run_node));
+
+  compile_node->set_assigned_device_name(n->assigned_device_name());
+  run_node->set_assigned_device_name(n->assigned_device_name());
+
+  CopyIncomingEdges(g, /*old_node=*/n, /*new_node=*/compile_node);
+  CopyIncomingEdges(g, /*old_node=*/n, /*new_node=*/run_node);
+
+  // The compilation_key output.
+  g->AddEdge(compile_node, 0, run_node, n->num_inputs());
+
+  MoveOutgoingEdges(g, /*old_node=*/n, /*new_node=*/run_node);
+  g->RemoveNode(n);
+
+  return Status::OK();
+}
+
+Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) {
+  Graph* graph = options.graph->get();
+
+  for (Node* n : graph->op_nodes()) {
+    // In all cases, only try to compile computational nodes.
+    if (n->IsSend() || n->IsRecv() || n->IsControlFlow()) {
+      continue;
+    }
+
+    // Only compile nodes that are marked for compilation by the
+    // compilation-marking pass (via the "_XlaCompiledKernel" attribute).
+    if (IsXlaCompiledKernel(*n)) {
+      TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndRun(graph, n));
+    }
+  }
+
+  if (VLOG_IS_ON(1)) {
+    dump_graph::DumpGraphToFile("build_xla_ops", *graph, options.flib_def);
+  }
+  return Status::OK();
+}
+}  // namespace tensorflow
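
As a rough illustration, here is a minimal sketch (not part of this change) of how BuildXlaOpsPass could be driven, assuming the usual GraphOptimizationPassOptions wiring; the graph-building step is elided:

    // Illustrative sketch only; the marked call node and its attributes
    // (kXlaCompiledKernelAttr, kXlaNumConstantArgsAttr,
    // kXlaNumResourceArgsAttr) are assumed to exist.
    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    // ... build a graph containing a function-call node carrying those attrs ...
    FunctionLibraryDefinition flib_def(OpRegistry::Global(), {});
    GraphOptimizationPassOptions options;
    options.graph = &graph;
    options.flib_def = &flib_def;
    BuildXlaOpsPass pass;
    TF_CHECK_OK(pass.Run(options));
    // Each marked call node is now an _XlaCompile/_XlaRun pair, with the
    // compilation key (output 0 of _XlaCompile) wired into _XlaRun.
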
diff --git a/tensorflow/compiler/jit/build_xla_launch_ops_pass.h b/tensorflow/compiler/jit/build_xla_ops_pass.h
similarity index 71%
rename from tensorflow/compiler/jit/build_xla_launch_ops_pass.h
rename to tensorflow/compiler/jit/build_xla_ops_pass.h
index 1dfea93..1dd38fa 100644
--- a/tensorflow/compiler/jit/build_xla_launch_ops_pass.h
+++ b/tensorflow/compiler/jit/build_xla_ops_pass.h
@@ -13,19 +13,21 @@
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_
-#define TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_
+#ifndef TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_
+#define TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_
 
 #include "tensorflow/core/common_runtime/optimization_registry.h"
 #include "tensorflow/core/lib/core/status.h"
 
 namespace tensorflow {
 
-class BuildXlaLaunchOpsPass : public GraphOptimizationPass {
+// Adds _XlaCompile and _XlaRun operations to the TF graph; together they
+// compile and execute (using XLA) the TF function calls marked with
+// "_XlaCompiledKernel".
+class BuildXlaOpsPass : public GraphOptimizationPass {
  public:
   Status Run(const GraphOptimizationPassOptions& options) override;
 };
 
 }  // namespace tensorflow
 
-#endif  // TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_
+#endif  // TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_
diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc
index 56b034a..6f1ff85 100644
--- a/tensorflow/compiler/jit/create_xla_launch_op.cc
+++ b/tensorflow/compiler/jit/create_xla_launch_op.cc
@@ -16,7 +16,7 @@
 
 #include "absl/memory/memory.h"
 #include "tensorflow/compiler/jit/defs.h"
-#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
+#include "tensorflow/compiler/jit/kernels/xla_ops.h"
 #include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
 #include "tensorflow/compiler/tf2xla/const_analysis.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
index ae7a22f..e0632ff 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -22,6 +22,7 @@
 #include <unordered_map>
 #include <vector>
 
+#include "absl/strings/match.h"
 #include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
 #include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
@@ -58,6 +59,22 @@
 const char* const kXlaHostTransferSequencerAttr =
     "_xla_host_transfer_sequencer";
 
+void SortControlInputs(GraphDef* gdef) {
+  int64 num_nodes = gdef->node_size();
+  for (int64 i = 0; i < num_nodes; ++i) {
+    NodeDef* node = gdef->mutable_node(i);
+    // Stable sort control inputs and leave the order of data inputs unchanged.
+    std::stable_sort(node->mutable_input()->begin(),
+                     node->mutable_input()->end(),
+                     [](const string& a, const string& b) {
+                       bool a_is_control = absl::StartsWith(a, "^");
+                       bool b_is_control = absl::StartsWith(b, "^");
+                       return (!a_is_control && b_is_control) ||
+                              (a_is_control && b_is_control && a < b);
+                     });
+  }
+}
+
 namespace {
 
 bool AreAllParentsGuaranteedConst(
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
index 9265895..90354a8 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
@@ -102,6 +102,12 @@
 // Name of the attribute containing the number of resource variable arguments.
 extern const char* const kXlaNumResourceArgsAttr;
 
+// Sorts each node's control inputs by their names. This guarantees that for
+// two structurally equivalent GraphDefs, we get the same traversal ordering
+// on each node's control input fields.
+// TODO(hpucha): Move the utilities to a more appropriate place.
+void SortControlInputs(GraphDef* gdef);
+
 class EncapsulateSubgraphsPass : public GraphOptimizationPass {
  public:
   Status Run(const GraphOptimizationPassOptions& options) override;
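
For intuition, a small hypothetical example (not part of this change) of the ordering SortControlInputs produces: data inputs keep their relative order and precede control inputs, which are sorted by name:

    GraphDef gdef;
    NodeDef* n = gdef.add_node();
    n->add_input("^y");  // control input
    n->add_input("x");   // data input
    n->add_input("^a");  // control input
    SortControlInputs(&gdef);
    // n->input() is now {"x", "^a", "^y"}: data first, control inputs sorted.
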
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
new file mode 100644
index 0000000..97ef8cd
--- /dev/null
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
@@ -0,0 +1,360 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
+
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
+#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/lib/strings/proto_serialization.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/fingerprint.h"
+
+namespace tensorflow {
+
+const char* const EncapsulateXlaComputationsPass::kXlaClusterAttr =
+    "_xla_compile_id";
+
+namespace {
+
+const char* const kXlaClusterOutput = "XlaClusterOutput";
+
+// Checks if a graph node is marked to be a guaranteed constant.
+bool is_guaranteed_constant(const Node& n) {
+  bool guaranteed_constant = false;
+  if (!GetNodeAttr(n.attrs(), "_is_guaranteed_constant", &guaranteed_constant)
+           .ok()) {
+    return false;
+  }
+  return guaranteed_constant;
+}
+
+// Finds the `index` of an _Arg or _Retval node.
+Status GetIndexAttr(const Node& n, int num_args, int* index) {
+  TF_RETURN_IF_ERROR(GetNodeAttr(n.attrs(), "index", index));
+  if (*index < 0 || *index >= num_args) {
+    return errors::InvalidArgument("Invalid ", n.type_string(), " number ",
+                                   *index);
+  }
+  return Status::OK();
+}
+
+// Returns the data type of the destination of an edge.
+DataType EdgeType(const Edge* edge) {
+  return edge->dst()->input_type(edge->dst_input());
+}
+
+// Adds the control inputs of `node` to `*deps`.
+void AddControlInputs(const Node& node, gtl::FlatSet<Node*>* deps) {
+  for (const Edge* edge : node.in_edges()) {
+    if (edge->IsControlEdge()) {
+      deps->insert(edge->src());
+    }
+  }
+}
+
+// Adds the control outputs of `node` to `*deps`.
+void AddControlOutputs(const Node& node, gtl::FlatSet<Node*>* deps) {
+  for (const Edge* edge : node.out_edges()) {
+    if (edge->IsControlEdge()) {
+      deps->insert(edge->dst());
+    }
+  }
+}
+
+// Rewrite function to be passed to EncapsulateSubgraphsInFunctions that sorts
+// the arguments into the order expected by XlaLaunch computations:
+// 1) arguments
+// 2) resource variable arguments
+// See the documentation of EncapsulateSubgraphsInFunctions for the meaning
+// of the arguments.
+//
+// TODO(b/113166435): Ordering constraints on XlaLaunch op can be relaxed.
+Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
+                       std::unique_ptr<Graph>* graph_ptr,
+                       std::vector<int>* input_permutation,
+                       std::vector<int>* output_permutation,
+                       NodeDef* call_def) {
+  Graph* graph = graph_ptr->get();
+  const int num_args = input_permutation->size();
+  const int num_retvals = output_permutation->size();
+
+  std::vector<Node*> args;
+  std::vector<Node*> retvals;
+  args.reserve(num_args);
+  retvals.reserve(num_retvals);
+  for (Node* n : graph->nodes()) {
+    if (n->type_string() == "_Arg") {
+      // Check if this is a guaranteed constant.
+      if (is_guaranteed_constant(*n)) {
+        return errors::InvalidArgument(
+            "Guaranteed constants are not supported (", n->name(), ")");
+      }
+      args.push_back(n);
+    } else if (n->type_string() == "_Retval") {
+      retvals.push_back(n);
+    }
+  }
+
+  if (std::find(args.begin(), args.end(), nullptr) != args.end()) {
+    return errors::InvalidArgument("Missing or non-consecutive arguments");
+  }
+
+  // Reorders the arguments.
+  std::sort(args.begin(), args.end(), [&](Node* a, Node* b) {
+    // Non-resources appear before resources
+    bool a_is_resource = (a->output_type(0) == DT_RESOURCE);
+    bool b_is_resource = (b->output_type(0) == DT_RESOURCE);
+    // Uses the name as a tiebreaker so the output is deterministic.
+    StringPiece a_name(a->name());
+    StringPiece b_name(b->name());
+    return std::tie(a_is_resource, a_name) < std::tie(b_is_resource, b_name);
+  });
+
+  // Sorts the retvals by name so the order is deterministic.
+  std::sort(retvals.begin(), retvals.end(),
+            [](Node* a, Node* b) { return a->name() < b->name(); });
+
+  // Computes the permutation to produce the correct argument order, and
+  // updates the argument indices.
+  int variable_start_index = num_args;
+  for (int i = 0; i < num_args; ++i) {
+    int index;
+    TF_RETURN_IF_ERROR(GetIndexAttr(*args[i], num_args, &index));
+    if (args[i]->output_type(0) == DT_RESOURCE &&
+        variable_start_index == num_args) {
+      variable_start_index = i;
+    }
+    (*input_permutation)[index] = i;
+    args[i]->AddAttr("index", i);
+  }
+  VLOG(4) << "variable_start_index: " << variable_start_index;
+
+  // Computes the permutation to produce the correct retval order, and
+  // updates the retval indices.
+  for (int i = 0; i < num_retvals; ++i) {
+    int index;
+    TF_RETURN_IF_ERROR(GetIndexAttr(*retvals[i], num_retvals, &index));
+    (*output_permutation)[index] = i;
+    retvals[i]->AddAttr("index", i);
+  }
+
+  AddNodeAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, call_def->name(),
+              call_def);
+  AddNodeAttr("_variable_start_index", variable_start_index, call_def);
+
+  // Uniquify the function name.
+  GraphDef gdef;
+  graph->ToGraphDef(&gdef);
+
+  // Before serialization, sort each node's control inputs to improve
+  // determinism. Sorting control inputs helps (but does not by itself
+  // guarantee) a deterministic serialization and fingerprint; other sources
+  // of nondeterminism include unstable node ordering.
+  SortControlInputs(&gdef);
+  // Fingerprint the function.
+  // Nondeterminism in serialization would not lead to incorrect results, but
+  // may cause spurious cache misses. DeterministicSerialization is a
+  // best-effort deterministic serialization.
+  string serialized;
+  TF_RET_CHECK(SerializeToStringDeterministic(gdef, &serialized));
+  uint64 fingerprint = Fingerprint64(serialized);
+  LOG(INFO) << "Subgraph fingerprint:" << fingerprint;
+  call_def->set_op(absl::StrCat(call_def->op(), "_", fingerprint));
+  return Status::OK();
+}
+
+}  // namespace
+
+/*static*/ Status EncapsulateXlaComputationsPass::Encapsulate(
+    std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def) {
+  // Check for undeclared outputs before encapsulation, so we can give a
+  // better error message.
+  // TODO(phawkins): merge this with the encapsulation code to avoid the extra
+  // O(n) pass over the edges.
+  for (const Edge* e : (*graph)->edges()) {
+    if (!e->IsControlEdge() &&
+        e->src()->attrs().Find(kXlaClusterAttr) != nullptr &&
+        e->dst()->attrs().Find(kXlaClusterAttr) == nullptr &&
+        e->dst()->type_string() != kXlaClusterOutput) {
+      return errors::InvalidArgument(
+          "Undeclared output of XLA computation. A common cause of this error "
+          "is variable initializers that depend on the XLA computation. Edge: ",
+          e->src()->name(), ":", e->src_output(), " -> ", e->dst()->name(), ":",
+          e->dst_input());
+    }
+  }
+
+  auto output = absl::make_unique<Graph>((*graph)->op_registry());
+  TF_RETURN_WITH_CONTEXT_IF_ERROR(
+      EncapsulateSubgraphsInFunctions(
+          kXlaClusterAttr, "", **graph, RewriteSubgraph,
+          /*reuse_existing_functions=*/true, &output, flib_def),
+      "EncapsulateXlaComputationsPass failed");
+  graph->swap(output);
+  return Status::OK();
+}
+
+/*static*/ Status EncapsulateXlaComputationsPass::BuildXlaLaunchOps(
+    Graph* graph) {
+  // Finds all of the XlaLaunch function calls, to avoid mutating the graph
+  // while iterating.
+  std::vector<Node*> launch_nodes;
+  for (Node* n : graph->nodes()) {
+    string name;
+    if (GetNodeAttr(n->attrs(), kXlaClusterAttr, &name).ok()) {
+      launch_nodes.push_back(n);
+    }
+  }
+
+  // Replaces each launch function call together with its neighboring
+  // XlaClusterOutput nodes with an XlaLaunch node.
+  for (Node* launch : launch_nodes) {
+    int variable_start_index;
+    TF_RETURN_IF_ERROR(GetNodeAttr(launch->attrs(), "_variable_start_index",
+                                   &variable_start_index));
+
+    std::vector<const Edge*> in_edges;
+    TF_RETURN_IF_ERROR(launch->input_edges(&in_edges));
+
+    const int num_inputs = in_edges.size();
+    const int num_variables = num_inputs - variable_start_index;
+    const int num_args = variable_start_index;
+
+    VLOG(4) << "Launch node '" << launch->name() << "'"
+            << " input edges: " << in_edges.size() << " num_args: " << num_args
+            << " num_variables: " << num_variables;
+
+    std::vector<Node*> nodes_to_remove = {launch};
+
+    // Data and control inputs to the new XlaLaunch node.
+    std::vector<std::pair<Node*, int>> data_inputs(num_inputs);
+    gtl::FlatSet<Node*> control_inputs;
+    DataTypeVector arg_types(num_args);
+
+    AddControlInputs(*launch, &control_inputs);
+
+    for (int i = 0; i < num_args; ++i) {
+      const Edge* edge = in_edges[i];
+      data_inputs[i] = {edge->src(), edge->src_output()};
+      arg_types[i] = EdgeType(edge);
+    }
+
+    // Appends the variable inputs.
+    for (int i = 0; i < num_variables; ++i) {
+      int pos = variable_start_index + i;
+      const Edge* edge = in_edges[pos];
+      data_inputs[pos] = {edge->src(), edge->src_output()};
+    }
+
+    // Outputs.
+    const int num_outputs = launch->output_types().size();
+    gtl::FlatSet<Node*> control_outputs;
+    std::vector<std::vector<std::pair<Node*, int>>> data_outputs(num_outputs);
+    DataTypeVector output_types(num_outputs);
+
+    for (const Edge* le : launch->out_edges()) {
+      if (le->IsControlEdge()) {
+        control_outputs.insert(le->dst());
+      } else {
+        TF_RET_CHECK(le->src_output() < num_outputs);
+        Node* output_node = le->dst();
+
+        TF_RET_CHECK(output_node->type_string() == kXlaClusterOutput)
+            << le->DebugString();
+        nodes_to_remove.push_back(output_node);
+
+        for (const Edge* oe : output_node->out_edges()) {
+          TF_RET_CHECK(!oe->IsControlEdge());
+          data_outputs[le->src_output()].push_back(
+              {oe->dst(), oe->dst_input()});
+        }
+        output_types[le->src_output()] = output_node->input_type(0);
+
+        AddControlOutputs(*output_node, &control_outputs);
+      }
+    }
+
+    NodeDef def;
+    def.set_name(launch->name());
+
+    // Target the XLA CPU/GPU backends.
+    VLOG(2) << "Replacing with XlaLaunch";
+    def.set_op("XlaLaunch");
+    AddNodeAttr("Tconstants", DataTypeVector{}, &def);
+    AddNodeAttr("Targs", arg_types, &def);
+    AddNodeAttr("Nresources", num_variables, &def);
+    AddNodeAttr("Tresults", output_types, &def);
+    NameAttrList function;
+    function.set_name(launch->type_string());
+    AddNodeAttr("function", function, &def);
+
+    for (Node* node : nodes_to_remove) {
+      VLOG(2) << "Deleting node " << node->DebugString();
+      // Ensure that we do not attempt to add control edges to nodes that are
+      // deleted.
+      control_inputs.erase(node);
+      control_outputs.erase(node);
+      graph->RemoveNode(node);
+    }
+
+    Status status;
+    Node* xla_launch = graph->AddNode(def, &status);
+    if (!status.ok()) {
+      return status;
+    }
+    for (int i = 0; i < data_inputs.size(); ++i) {
+      graph->AddEdge(data_inputs[i].first, data_inputs[i].second, xla_launch,
+                     i);
+    }
+    for (Node* n : control_inputs) {
+      graph->AddControlEdge(n, xla_launch);
+    }
+    for (int i = 0; i < data_outputs.size(); ++i) {
+      for (const auto& successor : data_outputs[i]) {
+        graph->AddEdge(xla_launch, i, successor.first, successor.second);
+      }
+    }
+    for (Node* n : control_outputs) {
+      graph->AddControlEdge(xla_launch, n);
+    }
+  }
+  return Status::OK();
+}
+
+Status EncapsulateXlaComputationsPass::Run(
+    const GraphOptimizationPassOptions& options) {
+  VLOG(1) << "EncapsulateXlaComputations(): "
+          << dump_graph::DumpGraphToFile("encapsulate_xla_computations_before",
+                                         **options.graph, options.flib_def);
+
+  TF_RETURN_IF_ERROR(Encapsulate(options.graph, options.flib_def));
+  VLOG(1) << "EncapsulateXlaComputations() half-way: "
+          << dump_graph::DumpGraphToFile("encapsulate_xla_computations_halfway",
+                                         **options.graph, options.flib_def);
+
+  TF_RETURN_IF_ERROR(BuildXlaLaunchOps(options.graph->get()));
+  VLOG(1) << "EncapsulateXlaComputations() finished: "
+          << dump_graph::DumpGraphToFile("encapsulate_xla_computations_after",
+                                         **options.graph, options.flib_def);
+  return Status::OK();
+}
+
+}  // namespace tensorflow
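
For context, a hypothetical fragment (not part of this change) of the kind of graph the undeclared-output check in Encapsulate() rejects; only the consumer routed through XlaClusterOutput counts as a declared output:

    Scope scope = Scope::NewRootScope().ExitOnError();
    auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32);
    auto e = ops::Identity(scope.WithOpName("E"), a);
    e.node()->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0");
    auto out = ops::XlaClusterOutput(scope.WithOpName("Out"), e);  // declared
    auto leak = ops::Identity(scope.WithOpName("Leak"), e);        // undeclared:
    // Encapsulate() would return the InvalidArgument error for the edge E -> Leak.
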
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h
new file mode 100644
index 0000000..99e9dfd
--- /dev/null
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h
@@ -0,0 +1,60 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+// Rewrites computations generated by the xla.compile() Python code into
+// XlaLaunch nodes.
+//
+// xla.compile() does two main things:
+// a) marks operators that make up an XLA computation with the attribute
+//    _xla_compile_id=XYZ, where XYZ is a unique key.
+// b) adds XlaClusterOutput nodes to represent outputs of the computation.
+//    These nodes are not marked with the _xla_compile_id attribute.
+
+#ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_
+#define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_
+
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace tensorflow {
+
+// Encapsulates nodes marked with the _xla_compile_id attribute into
+// XlaLaunch operators.
+class EncapsulateXlaComputationsPass : public GraphOptimizationPass {
+ public:
+  static const char* const kXlaClusterAttr;  // _xla_compile_id
+
+  Status Run(const GraphOptimizationPassOptions& options) override;
+
+  // The following methods are public only for unit tests.
+
+  // This pass has two stages:
+  // a) first, we call EncapsulateSubgraphsPass to encapsulate all nodes
+  //    marked with the same _xla_compile_id attribute into functions. These
+  //    functions contain the computations to be passed to XlaLaunch. During
+  //    encapsulation, we sort the arguments into the order expected by
+  //    XlaLaunch.
+  static Status Encapsulate(std::unique_ptr<Graph>* graph,
+                            FunctionLibraryDefinition* flib_def);
+
+  // b) we rewrite the function calls generated in phase (a) into XlaLaunch
+  //    operators. We also convert the XlaClusterOutput output nodes of the
+  //    function call into the outputs of the XlaLaunch operator.
+  static Status BuildXlaLaunchOps(Graph* graph);
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_
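
A minimal sketch of the two-stage flow described above, mirroring the unit tests added later in this change; the graph-building step is elided:

    FunctionLibraryDefinition flib_def(OpRegistry::Global(), {});
    std::unique_ptr<Graph> graph(new Graph(&flib_def));
    // ... add nodes tagged with the _xla_compile_id attribute, plus
    // XlaClusterOutput nodes for the computation's outputs ...
    TF_CHECK_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def));
    TF_CHECK_OK(EncapsulateXlaComputationsPass::BuildXlaLaunchOps(graph.get()));
    // The tagged nodes are now a single XlaLaunch node; the XlaClusterOutput
    // nodes are folded into its outputs.
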
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
new file mode 100644
index 0000000..f643fb0
--- /dev/null
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
@@ -0,0 +1,346 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
+
+#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/resource_variable_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
+#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_op.h"
+#include "tensorflow/compiler/tf2xla/test_util.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/lib/strings/proto_serialization.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/util/equal_graph_def.h"
+#include "tensorflow/core/util/ptr_util.h"
+
+namespace tensorflow {
+
+static std::unique_ptr<Graph> MakeOuterGraph(
+    const FunctionLibraryDefinition& flib_def, const string& function) {
+  Scope scope = Scope::NewRootScope().ExitOnError();
+  TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib_def.ToProto()));
+
+  auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32);
+  auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT);
+  auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32);
+  auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT);
+  auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE);
+  auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE);
+  auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
+
+  NodeDef def;
+  TF_CHECK_OK(
+      NodeDefBuilder("launch0", function, &flib_def)
+          .Input(a.node()->name(), 0, DT_INT32)
+          .Input(b.node()->name(), 0, DT_FLOAT)
+          .Input(c.node()->name(), 0, DT_INT32)
+          .Input(d.node()->name(), 0, DT_FLOAT)
+          .Input(u.node()->name(), 0, DT_RESOURCE)
+          .Input(v.node()->name(), 0, DT_RESOURCE)
+          .Input(w.node()->name(), 0, DT_RESOURCE)
+          .Attr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0")
+          .Attr("_variable_start_index", 4)
+          .Finalize(&def));
+
+  Status status;
+  Node* launch = scope.graph()->AddNode(def, &status);
+  TF_CHECK_OK(status);
+  TF_CHECK_OK(scope.DoShapeInference(launch));
+  scope.graph()->AddEdge(a.node(), 0, launch, 0);
+  scope.graph()->AddEdge(b.node(), 0, launch, 1);
+  scope.graph()->AddEdge(c.node(), 0, launch, 2);
+  scope.graph()->AddEdge(d.node(), 0, launch, 3);
+  scope.graph()->AddEdge(u.node(), 0, launch, 4);
+  scope.graph()->AddEdge(v.node(), 0, launch, 5);
+  scope.graph()->AddEdge(w.node(), 0, launch, 6);
+
+  auto out0 =
+      ops::XlaClusterOutput(scope.WithOpName("Out0"), Output(launch, 0));
+  auto out1 =
+      ops::XlaClusterOutput(scope.WithOpName("Out1"), Output(launch, 1));
+  auto out2 =
+      ops::XlaClusterOutput(scope.WithOpName("Out2"), Output(launch, 2));
+  auto out3 =
+      ops::XlaClusterOutput(scope.WithOpName("Out3"), Output(launch, 3));
+
+  auto consumer0_a = ops::Identity(scope.WithOpName("consumer0_a"), out0);
+  auto consumer0_b = ops::Identity(scope.WithOpName("consumer0_b"), out0);
+  auto consumer0_c = ops::Identity(scope.WithOpName("consumer0_c"), out0);
+  auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1);
+  auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2);
+  auto consumer3 = ops::Identity(scope.WithOpName("consumer3"), out3);
+
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+  TF_CHECK_OK(scope.ToGraph(graph.get()));
+  return graph;
+}
+
+// Makes the body graph of an encapsulated function for use in tests.
+static std::unique_ptr<Graph> MakeBodyGraph() {
+  Scope scope = Scope::NewRootScope().ExitOnError();
+
+  auto arg0 = ops::_Arg(scope.WithOpName("a_0_arg"), DT_INT32, 0);
+  auto arg1 = ops::_Arg(scope.WithOpName("b_0_arg"), DT_FLOAT, 1);
+  auto arg2 = ops::_Arg(scope.WithOpName("c_0_arg"), DT_INT32, 2);
+  auto arg3 = ops::_Arg(scope.WithOpName("d_0_arg"), DT_FLOAT, 3);
+
+  auto arg4 = ops::_Arg(scope.WithOpName("u_0_arg"), DT_RESOURCE, 4);
+  auto arg5 = ops::_Arg(scope.WithOpName("v_0_arg"), DT_RESOURCE, 5);
+  auto arg6 = ops::_Arg(scope.WithOpName("w_0_arg"), DT_RESOURCE, 6);
+
+  auto add_attrs = [](Node* node) {
+    node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0");
+  };
+
+  auto b_identity = ops::Identity(scope.WithOpName("B_identity"), arg1);
+
+  auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), arg4, DT_FLOAT);
+  add_attrs(read_u.node());
+  auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), arg5, DT_FLOAT);
+  add_attrs(read_v.node());
+  auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), arg6, DT_FLOAT);
+  add_attrs(read_w.node());
+
+  auto e = ops::Add(scope.WithOpName("E"), arg0, arg2);
+  add_attrs(e.node());
+  auto f = ops::Add(scope.WithOpName("F"), read_v, read_w);
+  add_attrs(f.node());
+  auto g = ops::Add(scope.WithOpName("G"), f, arg3);
+  add_attrs(g.node());
+
+  auto out0 = ops::_Retval(scope.WithOpName("b_identity_0_retval_RetVal"),
+                           b_identity, 0);
+  auto out1 = ops::_Retval(scope.WithOpName("e_0_retval_RetVal"), e, 1);
+  auto out2 = ops::_Retval(scope.WithOpName("g_0_retval_RetVal"), g, 2);
+  auto out3 =
+      ops::_Retval(scope.WithOpName("readu_0_retval_RetVal"), read_u, 3);
+
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+  TF_CHECK_OK(scope.ToGraph(graph.get()));
+  return graph;
+}
+
+TEST(EncapsulateXlaComputations, DeterministicEncapsulate) {
+  // Tests that the control edge insertion order doesn't affect the cache key
+  // (cluster name) generated by the encapsulate pass.
+  auto get_serialized_graph = [](bool control_input_reversed,
+                                 bool operand_reversed) -> string {
+    FunctionLibraryDefinition flib_def(OpRegistry::Global(), {});
+    std::unique_ptr<Graph> graph(new Graph(&flib_def));
+    {
+      Scope scope = Scope::NewRootScope().ExitOnError();
+      auto a0 = ops::Placeholder(scope.WithOpName("A0"), DT_INT32);
+      auto a1 = ops::Placeholder(scope.WithOpName("A1"), DT_INT32);
+
+      ops::Add e = operand_reversed ? ops::Add(scope.WithOpName("E"), a0, a1)
+                                    : ops::Add(scope.WithOpName("E"), a1, a0);
+
+      auto add_attrs = [](Node* node) {
+        node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr,
+                      "launch0");
+      };
+      add_attrs(e.node());
+
+      TF_CHECK_OK(scope.ToGraph(graph.get()));
+      auto get_node_in_graph = [&graph](Node* node) {
+        return graph->FindNodeId(node->id());
+      };
+      // Insert control edges in a different order. The order should not
+      // affect the encapsulated or serialized graph.
+      if (!control_input_reversed) {
+        graph->AddControlEdge(get_node_in_graph(a0.node()),
+                              get_node_in_graph(e.node()), true);
+        graph->AddControlEdge(get_node_in_graph(a1.node()),
+                              get_node_in_graph(e.node()), true);
+      } else {
+        graph->AddControlEdge(get_node_in_graph(a1.node()),
+                              get_node_in_graph(e.node()), true);
+        graph->AddControlEdge(get_node_in_graph(a0.node()),
+                              get_node_in_graph(e.node()), true);
+      }
+    }
+    TF_CHECK_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def));
+    GraphDef gdef;
+    graph->ToGraphDef(&gdef);
+    // Before serialization, sort control inputs first to remove
+    // nondeterminism.
+    SortControlInputs(&gdef);
+    string serialized;
+    SerializeToStringDeterministic(gdef, &serialized);
+    return serialized;
+  };
+
+  // Changing the order of control inputs shouldn't affect the generated graph.
+  EXPECT_EQ(get_serialized_graph(/*control_input_reversed=*/true,
+                                 /*operand_reversed=*/false),
+            get_serialized_graph(/*control_input_reversed=*/false,
+                                 /*operand_reversed=*/false));
+
+  // Changing the order of data inputs should affect the generated graph.
+  EXPECT_NE(get_serialized_graph(/*control_input_reversed=*/false,
+                                 /*operand_reversed=*/true),
+            get_serialized_graph(/*control_input_reversed=*/false,
+                                 /*operand_reversed=*/false));
+}
+
+TEST(EncapsulateXlaComputations, Encapsulate) {
+  FunctionLibraryDefinition flib_def(OpRegistry::Global(), {});
+  std::unique_ptr<Graph> graph(new Graph(&flib_def));
+  {
+    Scope scope = Scope::NewRootScope().ExitOnError();
+    auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32);
+    auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT);
+    auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32);
+    auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT);
+    auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE);
+    auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE);
+    auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
+
+    auto add_attrs = [](Node* node) {
+      node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0");
+    };
+
+    auto b_identity = ops::Identity(scope.WithOpName("B_identity"), b);
+    add_attrs(b_identity.node());
+
+    auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), u, DT_FLOAT);
+    add_attrs(read_u.node());
+    auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), v, DT_FLOAT);
+    add_attrs(read_v.node());
+    auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), w, DT_FLOAT);
+    add_attrs(read_w.node());
+
+    auto e = ops::Add(scope.WithOpName("E"), a, c);
+    add_attrs(e.node());
+    auto f = ops::Add(scope.WithOpName("F"), read_v, read_w);
+    add_attrs(f.node());
+    auto g = ops::Add(scope.WithOpName("G"), f, d);
+    add_attrs(g.node());
+
+    auto out0 = ops::XlaClusterOutput(scope.WithOpName("Out0"), b_identity);
+    auto out1 = ops::XlaClusterOutput(scope.WithOpName("Out1"), e);
+    auto out2 = ops::XlaClusterOutput(scope.WithOpName("Out2"), g);
+    auto out3 = ops::XlaClusterOutput(scope.WithOpName("Out3"), read_u);
+
+    auto consumer0_a = ops::Identity(scope.WithOpName("consumer0_a"), out0);
+    auto consumer0_b = ops::Identity(scope.WithOpName("consumer0_b"), out0);
+    auto consumer0_c = ops::Identity(scope.WithOpName("consumer0_c"), out0);
+    auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1);
+    auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2);
+    auto consumer3 = ops::Identity(scope.WithOpName("consumer3"), out3);
+    TF_ASSERT_OK(scope.ToGraph(graph.get()));
+  }
+
+  std::unique_ptr<Graph> graph_copy(new Graph(&flib_def));
+  CopyGraph(*graph, graph_copy.get());
+
+  TF_ASSERT_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def));
+
+  std::unordered_map<string, Node*> index = BuildNodeIndex(*graph);
+  string function = index.at("launch0")->type_string();
+
+  // Tests the outer graph is as expected.
+  {
+    std::unique_ptr<Graph> outer = MakeOuterGraph(flib_def, function);
+    GraphDef expected_def;
+    outer->ToGraphDef(&expected_def);
+
+    GraphDef actual_def;
+    graph->ToGraphDef(&actual_def);
+    TF_EXPECT_GRAPH_EQ_INTERNAL(expected_def, actual_def);
+  }
+
+  // Tests the encapsulated body graph is as expected.
+  {
+    std::unique_ptr<Graph> body = MakeBodyGraph();
+    GraphDef expected_body_def;
+    body->ToGraphDef(&expected_body_def);
+
+    InstantiationResultForTest result;
+    TF_EXPECT_OK(InstantiateFunctionForTest(function, flib_def, &result));
+
+    EXPECT_EQ((DataTypeVector{DT_INT32, DT_FLOAT, DT_INT32, DT_FLOAT,
+                              DT_RESOURCE, DT_RESOURCE, DT_RESOURCE}),
+              result.arg_types);
+    EXPECT_EQ((DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}),
+              result.ret_types);
+    TF_EXPECT_GRAPH_EQ(expected_body_def, result.gdef);
+  }
+
+  // Encapsulates the same computation again and verifies that we reuse the
+  // same function. Encapsulation should be deterministic to avoid
+  // recompilation.
+  TF_ASSERT_OK(
+      EncapsulateXlaComputationsPass::Encapsulate(&graph_copy, &flib_def));
+  std::unordered_map<string, Node*> index_copy = BuildNodeIndex(*graph_copy);
+  string function_copy = index_copy.at("launch0")->type_string();
+  EXPECT_EQ(function, function_copy);
+}
+
+TEST(EncapsulateXlaComputations, BuildXlaLaunchOp) {
+  std::unique_ptr<Graph> body_graph = MakeBodyGraph();
+  FunctionDefLibrary flib;
+  TF_ASSERT_OK(GraphToFunctionDef(*body_graph, "launch0", flib.add_function()));
+
+  FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);
+
+  std::unique_ptr<Graph> graph = MakeOuterGraph(flib_def, "launch0");
+  TF_ASSERT_OK(EncapsulateXlaComputationsPass::BuildXlaLaunchOps(graph.get()));
+
+  Scope scope = Scope::DisabledShapeInferenceScope().ExitOnError();
+  TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib));
+
+  auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32);
+  auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT);
+  auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32);
+  auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT);
+  auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE);
+  auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE);
+  auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
+
+  NameAttrList function;
+  function.set_name("launch0");
+  auto launch = ops::XlaLaunch(
+      scope.WithOpName("launch0"), std::initializer_list<Input>{},
+      std::initializer_list<Input>{a, b, c, d},
+      std::initializer_list<Input>{u, v, w},
+      DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}, function);
+
+  auto consumer0_a =
+      ops::Identity(scope.WithOpName("consumer0_a"), launch.results[0]);
+  auto consumer0_b =
+      ops::Identity(scope.WithOpName("consumer0_b"), launch.results[0]);
+  auto consumer0_c =
+      ops::Identity(scope.WithOpName("consumer0_c"), launch.results[0]);
+  auto consumer1 =
+      ops::Identity(scope.WithOpName("consumer1"), launch.results[1]);
+  auto consumer2 =
+      ops::Identity(scope.WithOpName("consumer2"), launch.results[2]);
+  auto consumer3 =
+      ops::Identity(scope.WithOpName("consumer3"), launch.results[3]);
+
+  GraphDef expected_def;
+  TF_ASSERT_OK(scope.ToGraphDef(&expected_def));
+
+  GraphDef actual_def;
+  graph->ToGraphDef(&actual_def);
+  TF_EXPECT_GRAPH_EQ(expected_def, actual_def);
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
index 5dcf754..085c0e5 100644
--- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
+++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
@@ -13,8 +13,9 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h"
+#include "tensorflow/compiler/jit/build_xla_ops_pass.h"
 #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
+#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
 #include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
 #include "tensorflow/compiler/jit/partially_decluster_pass.h"
 #include "tensorflow/core/common_runtime/optimization_registry.h"
@@ -23,6 +24,11 @@
 
 // PRE_PLACEMENT passes:
 
+// EncapsulateXlaComputationsPass rewrites computations generated by the
+// xla.compile() Python code into XlaLaunch nodes.
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 26,
+                      EncapsulateXlaComputationsPass);
+
 // from
 // third_party/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc
 // FunctionalizeControlFlowPass: 27
@@ -32,7 +38,8 @@
 // control flow structure (XlaIf/XlaWhile). Following passes must
 // handle those FunctionDef correctly.
 
-// POST_REWRITE_FOR_EXEC passes:
+// POST_REWRITE_FOR_EXEC passes that support auto-clustering to enable XLA:
+
 REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10,
                       MarkForCompilationPass);
 
@@ -48,6 +55,6 @@
 
 // Must run after EncapsulateSubgraphsPass.
 REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40,
-                      BuildXlaLaunchOpsPass);
+                      BuildXlaOpsPass);
 
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD
index 253a5d2..0839f1c 100644
--- a/tensorflow/compiler/jit/kernels/BUILD
+++ b/tensorflow/compiler/jit/kernels/BUILD
@@ -7,9 +7,9 @@
 )
 
 cc_library(
-    name = "xla_launch_op",
-    srcs = ["xla_launch_op.cc"],
-    hdrs = ["xla_launch_op.h"],
+    name = "xla_ops",
+    srcs = ["xla_ops.cc"],
+    hdrs = ["xla_ops.h"],
     deps = [
         "//tensorflow/compiler/jit:common",
         "//tensorflow/compiler/jit:xla_compilation_cache",
@@ -26,6 +26,7 @@
         "//tensorflow/core:lib",
         "//tensorflow/core:stream_executor_no_cuda",
         "//tensorflow/core/kernels:variable_ops",
+        "@com_google_absl//absl/memory",
     ],
     alwayslink = 1,
 )
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
deleted file mode 100644
index b6f2f63..0000000
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ /dev/null
@@ -1,276 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
-
-#include "tensorflow/compiler/jit/defs.h"
-#include "tensorflow/compiler/jit/xla_launch_util.h"
-#include "tensorflow/compiler/tf2xla/shape_util.h"
-#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
-#include "tensorflow/compiler/tf2xla/xla_compiler.h"
-#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/core/common_runtime/dma_helper.h"
-#include "tensorflow/core/common_runtime/function.h"
-#include "tensorflow/core/framework/allocator.h"
-#include "tensorflow/core/framework/node_def_util.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/kernels/variable_ops.h"
-#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/stream_executor_no_cuda.h"
-#include "tensorflow/core/util/stream_executor_util.h"
-
-namespace tensorflow {
-
-XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
-                                       const std::vector<int>& constants,
-                                       const std::vector<int>& resources,
-                                       const NameAttrList& function)
-    : OpKernel(ctx),
-      constants_(constants),
-      resources_(resources),
-      device_type_(ctx->device_type()),
-      function_(function) {
-  if (device_type_ == DeviceType(DEVICE_CPU)) {
-    platform_id_ = se::host::kHostPlatformId;
-  } else if (device_type_ == DeviceType(DEVICE_GPU)) {
-    platform_id_ = ctx->device()
-                       ->tensorflow_gpu_device_info()
-                       ->stream->parent()
-                       ->platform()
-                       ->id();
-  } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata_).ok()) {
-    use_multiple_streams_ = xla_device_metadata_->UseMultipleStreams();
-    platform_id_ = xla_device_metadata_->platform()->id();
-  }
-}
-
-Status XlaLocalLaunchBase::BuildCompilationCache(OpKernelContext* ctx,
-                                                 XlaCompilationCache** cache) {
-  if (xla_device_metadata_) {
-    *cache = new XlaCompilationCache(xla_device_metadata_->client(),
-                                     xla_device_metadata_->jit_device_type());
-    return Status::OK();
-  }
-
-  auto platform = se::MultiPlatformManager::PlatformWithId(platform_id_);
-  if (!platform.ok()) {
-    return platform.status();
-  }
-  xla::LocalClientOptions client_options;
-  client_options.set_platform(platform.ValueOrDie());
-  client_options.set_intra_op_parallelism_threads(
-      ctx->device()->tensorflow_cpu_worker_threads()->num_threads);
-  auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
-  if (!client.ok()) {
-    return client.status();
-  }
-  const XlaOpRegistry::DeviceRegistration* registration;
-  if (!XlaOpRegistry::GetCompilationDevice(device_type_.type(),
-                                           &registration)) {
-    return errors::InvalidArgument("No JIT device registered for ",
-                                   device_type_.type());
-  }
-  *cache = new XlaCompilationCache(
-      client.ValueOrDie(), DeviceType(registration->compilation_device_name));
-  return Status::OK();
-}
-
-void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
-  VLOG(1) << "XlaLocalLaunchOpBase::Compute "
-          << Canonicalize(function_.name(), AttrSlice(&function_.attr()));
-  // We store information about the JIT-compiled XLA computation
-  // in the ResourceMgr.
-  ResourceMgr* rm = ctx->resource_manager();
-  OP_REQUIRES(ctx, rm, errors::Internal("No resource manager."));
-
-  se::Stream* stream =
-      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
-
-  XlaCompilationCache* cache;
-  OP_REQUIRES_OK(ctx, rm->LookupOrCreate<XlaCompilationCache>(
-                          rm->default_container(), "xla_cache", &cache,
-                          [this, ctx](XlaCompilationCache** cache) {
-                            return BuildCompilationCache(ctx, cache);
-                          }));
-  // Hold the reference to the JIT during evaluation. (We could probably
-  // free it sooner because the ResourceMgr will retain a reference, but
-  // this is more obviously correct.)
-  core::ScopedUnref cache_ref(cache);
-
-  std::map<int, OptionalTensor> variables =
-      SnapshotResourceVariables(ctx, resources_);
-
-  xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());
-
-  XlaAllocator local_xla_allocator(client->backend().platform(),
-                                   ctx->device()->GetAllocator({}));
-  xla::DeviceMemoryAllocator* xla_allocator;
-  // If we are on an XlaDevice, use the underlying XLA platform's allocator
-  // directly. We could use the StreamExecutor's allocator which may
-  // theoretically be more correct, but XLA returns a nice OOM message in a
-  // Status and StreamExecutor does not.
-  //
-  // Importantly we can't use ctx->device()->GetAllocator() as the allocator
-  // (which local_xla_allocator above uses) as on an XlaDevice, this is a
-  // dummy allocator that returns XlaTensor objects. The XlaCompiler needs a
-  // real allocator to allocate real buffers.
-  if (xla_device_metadata_) {
-    xla_allocator = client->backend().memory_allocator();
-  } else {
-    xla_allocator = &local_xla_allocator;
-  }
-
-  XlaCompiler::Options options;
-  options.client = client;
-  if (ctx->op_device_context() != nullptr) {
-    options.device_ordinal =
-        ctx->op_device_context()->stream()->parent()->device_ordinal();
-  }
-  options.device_type = cache->device_type();
-  options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
-  options.graph_def_version = ctx->function_library()->graph_def_version();
-  options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId);
-  options.device_allocator = xla_allocator;
-  if (xla_device_metadata_) {
-    options.shape_representation_fn =
-        xla_device_metadata_->shape_representation_fn();
-  }
-
-  const XlaCompiler::CompilationResult* kernel;
-  xla::LocalExecutable* executable;
-
-  std::map<int, Tensor> constant_args;
-  for (int i : constants_) {
-    constant_args.insert({i, ctx->input(i)});
-  }
-  XlaCompiler::CompileOptions compile_options;
-  compile_options.is_entry_computation = true;
-  // If we resolve constants we never emit them on the device, meaning that if
-  // they are needed by a following computation the host has to transfer
-  // them. Not resolving constants is expected to be faster than resolving
-  // constants.
-  compile_options.resolve_compile_time_constants = true;
-  // Optimization: where possible, have the computation return a naked array
-  // rather than a one-element tuple.
-  compile_options.always_return_tuple = false;
-
-  OP_REQUIRES_OK(
-      ctx, cache->Compile(options, function_, constant_args, variables, ctx,
-                          &kernel, &executable, compile_options));
-
-  VLOG(1) << "Executing XLA Computation...";
-
-  XlaComputationLaunchContext launch_context(
-      client, xla_allocator,
-      /*allocate_xla_tensors=*/xla_device_metadata_ != nullptr,
-      use_multiple_streams_);
-  launch_context.PopulateInputs(ctx, kernel, variables);
-
-  // Execute the computation.
-  VLOG(2) << "Executing computation.";
-  xla::ExecutableRunOptions run_options;
-  run_options.set_stream(stream);
-  run_options.set_allocator(xla_allocator);
-  run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
-  run_options.set_rng_seed(GetXLARandomSeed());
-  Env* env = Env::Default();
-  auto start_time = env->NowMicros();
-
-  auto run_result = executable->Run(launch_context.arguments(), run_options);
-  OP_REQUIRES(ctx, run_result.ok(), run_result.status());
-
-  auto elapsed = env->NowMicros() - start_time;
-  VLOG(2) << "Elapsed time: " << elapsed << "us";
-
-  OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs(
-                          ctx, kernel, run_result.ConsumeValueOrDie()));
-  VLOG(1) << "Done";
-}
-
-namespace {
-
-// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that
-// in error case, it returns RET instead of void.
-#define OP_REQUIRES_OK_RETURN(CTX, RET, ...)                \
-  do {                                                      \
-    ::tensorflow::Status _s(__VA_ARGS__);                   \
-    if (!TF_PREDICT_TRUE(_s.ok())) {                        \
-      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
-      return RET;                                           \
-    }                                                       \
-  } while (0)
-
-// Helper static functions to construct parameters for
-// XlaLocalLaunchBase constructor from OpKernelConstruction.
-std::vector<int> ConstantsVector(OpKernelConstruction* ctx) {
-  DataTypeVector constant_types;
-  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
-                        ctx->GetAttr("Tconstants", &constant_types));
-  std::vector<int> constants(constant_types.size());
-  std::iota(constants.begin(), constants.end(), 0);
-  return constants;
-}
-
-std::vector<int> ResourcesVector(OpKernelConstruction* ctx) {
-  DataTypeVector constant_types;
-  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
-                        ctx->GetAttr("Tconstants", &constant_types));
-
-  DataTypeVector arg_types;
-  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
-                        ctx->GetAttr("Targs", &arg_types));
-
-  int num_resources;
-  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
-                        ctx->GetAttr("Nresources", &num_resources));
-
-  std::vector<int> resources(num_resources);
-  std::iota(resources.begin(), resources.end(),
-            constant_types.size() + arg_types.size());
-  return resources;
-}
-
-NameAttrList FunctionAttr(OpKernelConstruction* ctx) {
-  const NameAttrList* func;
-  OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func));
-  return *func;
-}
-
-#undef OP_REQUIRES_OK_RETURN
-}  // namespace
-
-XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
-    : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx),
-                         FunctionAttr(ctx)) {}
-
-XlaLocalLaunchOp::~XlaLocalLaunchOp() {
-  VLOG(1) << "XlaLocalLaunchOp destroyed";
-}
-
-REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp);
-
-REGISTER_KERNEL_BUILDER(Name("XlaLaunch")
-                            .Device(DEVICE_GPU)
-                            .HostMemory("constants")
-                            .HostMemory("resources"),
-                        XlaLocalLaunchOp);
-
-}  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.h b/tensorflow/compiler/jit/kernels/xla_launch_op.h
deleted file mode 100644
index e0f10e9..0000000
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.h
+++ /dev/null
@@ -1,87 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_
-#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_
-
-#include "tensorflow/compiler/jit/xla_compilation_cache.h"
-#include "tensorflow/compiler/jit/xla_device.h"
-#include "tensorflow/core/framework/allocator.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/util/stream_executor_util.h"
-
-namespace tensorflow {
-
-// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp.
-// The only difference is that it does not require arguments to follow
-// the "constants, then regular args, then resources" order.
-// It takes vectors of constant and resource arguments explicitly.
-// It does not have corresponding OpDef because it is never present
-// in the GraphDef.
-// Currently, it is used by eager runtime. FunctionLibraryRuntime creates
-// this kernel when asked to create a kernel for an XLA-compiled function.
-class XlaLocalLaunchBase : public OpKernel {
- public:
-  XlaLocalLaunchBase(OpKernelConstruction* ctx,
-                     const std::vector<int>& constants,
-                     const std::vector<int>& resources,
-                     const NameAttrList& function);
-  XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete;
-  XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete;
-  ~XlaLocalLaunchBase() override = default;
-
-  void Compute(OpKernelContext* ctx) override;
-
- protected:
-  // Builds a XlaCompilationCache class suitable for the current device.
-  Status BuildCompilationCache(OpKernelContext* ctx,
-                               XlaCompilationCache** cache);
-
-  // Indexes of compile-time constant inputs
-  std::vector<int> constants_;
-  // Indexes of resource inputs
-  std::vector<int> resources_;
-
-  DeviceType device_type_;
-  NameAttrList function_;
-  se::Platform::Id platform_id_ = nullptr;
-  bool use_multiple_streams_ = false;
-  const XlaDevice::Metadata* xla_device_metadata_ = nullptr;
-};
-
-// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
-// which will be compiled and executed using XLA.  The XlaLocalLaunchOp is
-// responsible for handling interactions with the TensorFlow executor.
-// Once all inputs are present, and their shapes are known, the op can
-// use a 'XlaCompilationCache' to compile and execute code which is specific
-// to the shapes of input Tensors.
-// XlaLocalLaunchOp uses xla::LocalClient::Compile() and
-// xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device
-// memory.
-class XlaLocalLaunchOp : public XlaLocalLaunchBase {
- public:
-  explicit XlaLocalLaunchOp(OpKernelConstruction* ctx);
-  ~XlaLocalLaunchOp() override;
-
- private:
-  TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp);
-};
-
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_
diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc
new file mode 100644
index 0000000..c483841
--- /dev/null
+++ b/tensorflow/compiler/jit/kernels/xla_ops.cc
@@ -0,0 +1,488 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/kernels/xla_ops.h"
+
+#include "absl/memory/memory.h"
+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/variable_ops.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/util/stream_executor_util.h"
+
+namespace tensorflow {
+
+namespace {
+
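+// Fills in `result` with platform information for the device this kernel is
+// being built on: the platform id, XLA device metadata (if running on an
+// XlaDevice), and an allocator suitable for the XlaCompiler.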
+Status PlatformInfoFromContext(OpKernelConstruction* ctx,
+                               XlaPlatformInfo* result) {
+  DeviceType device_type = ctx->device_type();
+  se::Platform::Id platform_id = nullptr;
+  const XlaDevice::Metadata* xla_device_metadata = nullptr;
+  std::unique_ptr<XlaAllocator> xla_allocator;
+  xla::DeviceMemoryAllocator* device_allocator = nullptr;
+
+  if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
+    platform_id = se::host::kHostPlatformId;
+  } else if (ctx->device_type() == DeviceType(DEVICE_GPU)) {
+    platform_id = ctx->device()
+                      ->tensorflow_gpu_device_info()
+                      ->stream->parent()
+                      ->platform()
+                      ->id();
+  } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata).ok()) {
+    // If we are on an XlaDevice, use the underlying XLA platform's allocator
+    // directly. We could use the StreamExecutor's allocator which may
+    // theoretically be more correct, but XLA returns a nice OOM message in a
+    // Status and StreamExecutor does not.
+    //
+    // Importantly, we can't use ctx->device()->GetAllocator() as the
+    // allocator here (it is what xla_allocator below wraps): on an XlaDevice
+    // it is a dummy allocator that returns XlaTensor objects. The XlaCompiler
+    // needs a real allocator to allocate real buffers.
+
+    platform_id = xla_device_metadata->platform()->id();
+    device_allocator =
+        xla_device_metadata->client()->backend().memory_allocator();
+  }
+
+  if (!device_allocator) {
+    TF_ASSIGN_OR_RETURN(se::Platform* const platform,
+                        se::MultiPlatformManager::PlatformWithId(platform_id));
+    xla_allocator = absl::make_unique<XlaAllocator>(
+        platform, ctx->device()->GetAllocator({}));
+  }
+
+  *result = XlaPlatformInfo(device_type, platform_id, xla_device_metadata,
+                            std::move(xla_allocator), device_allocator);
+
+  return Status::OK();
+}
+
+// A closure describing how to run a compiled version of a TensorFlow function.
+//
+// It may seem unusual to stick the resource variable snapshots in this class.
+// This is necessary: we need to use the snapshots observed by the compiler as
+// the initial values for the resource variables (and cannot snapshot them again
+// during execution) because otherwise we risk observing a different snapshot
+// with shapes different from what we compiled for.
+class XlaExecutableClosure {
+ public:
+  explicit XlaExecutableClosure(
+      xla::LocalClient* client, xla::LocalExecutable* executable,
+      const XlaCompiler::CompilationResult* compilation_result,
+      std::map<int, OptionalTensor> resource_var_snapshots)
+      : client_(client),
+        executable_(executable),
+        compilation_result_(compilation_result),
+        resource_var_snapshots_(std::move(resource_var_snapshots)) {}
+
+  XlaExecutableClosure(XlaExecutableClosure&&) = default;
+  XlaExecutableClosure& operator=(XlaExecutableClosure&&) = default;
+
+  xla::LocalClient* client() const { return client_; }
+  xla::LocalExecutable* executable() const { return executable_; }
+  const XlaCompiler::CompilationResult* compilation_result() const {
+    return compilation_result_;
+  }
+  const std::map<int, OptionalTensor>& resource_var_snapshots() const {
+    return resource_var_snapshots_;
+  }
+
+ private:
+  xla::LocalClient* client_;
+  xla::LocalExecutable* executable_;
+  const XlaCompiler::CompilationResult* compilation_result_;
+  std::map<int, OptionalTensor> resource_var_snapshots_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosure);
+};
+
+// This maintains a mapping from a globally unique ID to XlaExecutableClosure
+// instances.
+class XlaExecutableClosureStore {
+ public:
+  XlaExecutableClosureStore() : key_counter_(0) {}
+
+  using KeyT = string;
+
+  KeyT Produce(XlaExecutableClosure result) {
+    mutex_lock l(mutex_);
+    KeyT key = absl::StrCat(key_counter_++);
+    bool insert_successful = closures_.emplace(key, std::move(result)).second;
+    DCHECK(insert_successful);
+    (void)insert_successful;
+    return key;
+  }
+
+  XlaExecutableClosure Consume(const KeyT& key) {
+    mutex_lock l(mutex_);
+    auto it = closures_.find(key);
+    DCHECK(it != closures_.end());
+    XlaExecutableClosure value = std::move(it->second);
+    closures_.erase(it);
+    return value;
+  }
+
+  static XlaExecutableClosureStore* Global() {
+    static XlaExecutableClosureStore* instance = new XlaExecutableClosureStore;
+    return instance;
+  }
+
+ private:
+  mutex mutex_;
+  int64 key_counter_ GUARDED_BY(mutex_);
+  gtl::FlatMap<KeyT, XlaExecutableClosure> closures_ GUARDED_BY(mutex_);
+
+  TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore);
+};
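+
+// Illustrative flow (mirrors XlaCompileOp::Compute and XlaRunOp::Compute
+// below; `closure` stands for a fully constructed XlaExecutableClosure):
+//
+//   XlaExecutableClosureStore::KeyT key =
+//       XlaExecutableClosureStore::Global()->Produce(std::move(closure));
+//   // ... `key` travels through the graph as a DT_STRING scalar tensor ...
+//   XlaExecutableClosure consumed =
+//       XlaExecutableClosureStore::Global()->Consume(key);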
+
+}  // namespace
+
+XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
+                                       const std::vector<int>& constants,
+                                       const std::vector<int>& resources,
+                                       const NameAttrList& function)
+    : OpKernel(ctx),
+      constants_(constants),
+      resources_(resources),
+      function_(function) {
+  OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_));
+}
+
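+// Builds an XlaCompilationCache for `platform_info`.  On an XlaDevice the
+// cache reuses the device's existing XLA client; otherwise a LocalClient is
+// created for the platform identified by `platform_info.platform_id()`.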
+static Status BuildCompilationCache(OpKernelContext* ctx,
+                                    const XlaPlatformInfo& platform_info,
+                                    XlaCompilationCache** cache) {
+  if (platform_info.xla_device_metadata()) {
+    *cache = new XlaCompilationCache(
+        platform_info.xla_device_metadata()->client(),
+        platform_info.xla_device_metadata()->jit_device_type());
+    return Status::OK();
+  }
+
+  auto platform =
+      se::MultiPlatformManager::PlatformWithId(platform_info.platform_id());
+  if (!platform.ok()) {
+    return platform.status();
+  }
+  xla::LocalClientOptions client_options;
+  client_options.set_platform(platform.ValueOrDie());
+  client_options.set_intra_op_parallelism_threads(
+      ctx->device()->tensorflow_cpu_worker_threads()->num_threads);
+  auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
+  if (!client.ok()) {
+    return client.status();
+  }
+  const XlaOpRegistry::DeviceRegistration* registration;
+  if (!XlaOpRegistry::GetCompilationDevice(platform_info.device_type().type(),
+                                           &registration)) {
+    return errors::InvalidArgument("No JIT device registered for ",
+                                   platform_info.device_type().type());
+  }
+  *cache = new XlaCompilationCache(
+      client.ValueOrDie(), DeviceType(registration->compilation_device_name));
+  return Status::OK();
+}
+
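+// Compiles `function` into an xla::LocalExecutable via the per-device
+// XlaCompilationCache stored in the ResourceMgr, snapshotting the resource
+// variables named by `resources` along the way.  Shared by
+// XlaLocalLaunchBase::Compute and XlaCompileOp::Compute below.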
+static Status CompileToLocalExecutable(
+    OpKernelContext* ctx, const NameAttrList& function,
+    const XlaPlatformInfo& platform_info, absl::Span<const int> resources,
+    absl::Span<const int> constants, xla::LocalClient** client,
+    std::map<int, OptionalTensor>* variables,
+    const XlaCompiler::CompilationResult** kernel,
+    xla::LocalExecutable** executable) {
+  // We store information about the JIT-compiled XLA computation
+  // in the ResourceMgr.
+  ResourceMgr* rm = ctx->resource_manager();
+  if (!rm) {
+    return errors::Internal("No resource manager.");
+  }
+
+  XlaCompilationCache* cache;
+  TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
+      rm->default_container(), "xla_cache", &cache,
+      [&](XlaCompilationCache** cache) {
+        return BuildCompilationCache(ctx, platform_info, cache);
+      }));
+  // Hold the reference to the JIT during evaluation. (We could probably
+  // free it sooner because the ResourceMgr will retain a reference, but
+  // this is more obviously correct.)
+  core::ScopedUnref cache_ref(cache);
+
+  *variables = SnapshotResourceVariables(ctx, resources);
+  *client = static_cast<xla::LocalClient*>(cache->client());
+
+  XlaCompiler::Options options;
+  options.client = *client;
+  if (ctx->op_device_context() != nullptr) {
+    options.device_ordinal =
+        ctx->op_device_context()->stream()->parent()->device_ordinal();
+  }
+  options.device_type = cache->device_type();
+  options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
+  options.graph_def_version = ctx->function_library()->graph_def_version();
+  options.allow_cpu_custom_calls =
+      (platform_info.platform_id() == se::host::kHostPlatformId);
+  options.device_allocator = platform_info.allocator();
+  if (platform_info.xla_device_metadata()) {
+    options.shape_representation_fn =
+        platform_info.xla_device_metadata()->shape_representation_fn();
+  }
+
+  std::map<int, Tensor> constant_args;
+  for (int i : constants) {
+    constant_args.insert({i, ctx->input(i)});
+  }
+  XlaCompiler::CompileOptions compile_options;
+  compile_options.is_entry_computation = true;
+  // Resolving compile-time constants means they are never emitted on the
+  // device, so if a following computation needs them the host has to
+  // transfer them.
+  compile_options.resolve_compile_time_constants = true;
+  // Optimization: where possible, have the computation return a naked array
+  // rather than a one-element tuple.
+  compile_options.always_return_tuple = false;
+
+  return cache->Compile(options, function, constant_args, *variables, ctx,
+                        kernel, executable, compile_options);
+}
+
+void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
+  VLOG(1) << "XlaLocalLaunchOpBase::Compute "
+          << Canonicalize(function_.name(), AttrSlice(&function_.attr()));
+
+  xla::LocalClient* client;
+  const XlaCompiler::CompilationResult* kernel;
+  xla::LocalExecutable* executable;
+  std::map<int, OptionalTensor> variables;
+
+  OP_REQUIRES_OK(
+      ctx, CompileToLocalExecutable(ctx, function_, platform_info_, resources_,
+                                    constants_, &client, &variables, &kernel,
+                                    &executable));
+
+  se::Stream* stream =
+      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
+
+  VLOG(1) << "Executing XLA Computation...";
+
+  XlaComputationLaunchContext launch_context(
+      client, platform_info_.allocator(),
+      /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
+      platform_info_.UseMultipleStreams());
+  launch_context.PopulateInputs(ctx, kernel, variables);
+
+  // Execute the computation.
+  VLOG(2) << "Executing computation.";
+  xla::ExecutableRunOptions run_options;
+  run_options.set_stream(stream);
+  run_options.set_allocator(platform_info_.allocator());
+  run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
+  run_options.set_rng_seed(GetXLARandomSeed());
+  Env* env = Env::Default();
+  auto start_time = env->NowMicros();
+
+  auto run_result = executable->Run(launch_context.arguments(), run_options);
+  OP_REQUIRES(ctx, run_result.ok(), run_result.status());
+
+  auto elapsed = env->NowMicros() - start_time;
+  VLOG(2) << "Elapsed time: " << elapsed << "us";
+
+  OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs(
+                          ctx, kernel, run_result.ConsumeValueOrDie()));
+  VLOG(1) << "Done";
+}
+
+namespace {
+
+// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that, in the
+// error case, it returns RET instead of void.
+#define OP_REQUIRES_OK_RETURN(CTX, RET, ...)                \
+  do {                                                      \
+    ::tensorflow::Status _s(__VA_ARGS__);                   \
+    if (!TF_PREDICT_TRUE(_s.ok())) {                        \
+      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
+      return RET;                                           \
+    }                                                       \
+  } while (0)
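+//
+// For example, in the helpers below a failed GetAttr both reports the error
+// on the OpKernelConstruction and returns an empty vector:
+//
+//   OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+//                         ctx->GetAttr("Tconstants", &constant_types));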
+
+// Helper static functions to construct parameters for the XlaLocalLaunchBase
+// constructor from an OpKernelConstruction.
+std::vector<int> ConstantsVector(OpKernelConstruction* ctx) {
+  DataTypeVector constant_types;
+  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+                        ctx->GetAttr("Tconstants", &constant_types));
+  std::vector<int> constants(constant_types.size());
+  std::iota(constants.begin(), constants.end(), 0);
+  return constants;
+}
+
+std::vector<int> ResourcesVector(OpKernelConstruction* ctx) {
+  DataTypeVector constant_types;
+  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+                        ctx->GetAttr("Tconstants", &constant_types));
+
+  DataTypeVector arg_types;
+  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+                        ctx->GetAttr("Targs", &arg_types));
+
+  int num_resources;
+  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+                        ctx->GetAttr("Nresources", &num_resources));
+
+  std::vector<int> resources(num_resources);
+  std::iota(resources.begin(), resources.end(),
+            constant_types.size() + arg_types.size());
+  return resources;
+}
+
+NameAttrList FunctionAttr(OpKernelConstruction* ctx) {
+  const NameAttrList* func;
+  OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func));
+  return *func;
+}
+
+#undef OP_REQUIRES_OK_RETURN
+}  // namespace
+
+XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
+    : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx),
+                         FunctionAttr(ctx)) {}
+
+XlaLocalLaunchOp::~XlaLocalLaunchOp() {
+  VLOG(1) << "XlaLocalLaunchOp destroyed";
+}
+
+XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx)
+    : OpKernel(ctx),
+      constants_(ConstantsVector(ctx)),
+      resources_(ResourcesVector(ctx)),
+      function_(FunctionAttr(ctx)) {
+  OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_));
+}
+
+void XlaCompileOp::Compute(OpKernelContext* ctx) {
+  xla::LocalClient* client;
+  const XlaCompiler::CompilationResult* kernel;
+  xla::LocalExecutable* executable;
+  std::map<int, OptionalTensor> variables;
+
+  OP_REQUIRES_OK(
+      ctx, CompileToLocalExecutable(ctx, function_, platform_info_, resources_,
+                                    constants_, &client, &variables, &kernel,
+                                    &executable));
+
+  // Each execution of an XlaCompile op creates a new XlaExecutableClosure, even
+  // if it didn't have to compile the cluster because of a compilation-cache
+  // hit.  This is because we at least need new snapshots of the resource
+  // variables.
+  XlaExecutableClosureStore::KeyT key =
+      XlaExecutableClosureStore::Global()->Produce(XlaExecutableClosure(
+          client, executable, kernel, std::move(variables)));
+
+  Allocator* cpu_allocator = [&] {
+    AllocatorAttributes host_alloc_attrs;
+    host_alloc_attrs.set_gpu_compatible(true);
+    host_alloc_attrs.set_on_host(true);
+    return ctx->device()->GetAllocator(host_alloc_attrs);
+  }();
+
+  Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({}));
+  compilation_key.flat<string>()(0) = key;
+
+  Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({}));
+  compilation_successful.flat<bool>()(0) = true;
+
+  ctx->set_output(0, compilation_key);
+  ctx->set_output(1, compilation_successful);
+}
+
+XlaRunOp::XlaRunOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+  OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_));
+}
+
+void XlaRunOp::Compute(OpKernelContext* ctx) {
+  Tensor key_tensor = ctx->input(ctx->num_inputs() - 1);
+  const XlaExecutableClosureStore::KeyT& key = key_tensor.flat<string>()(0);
+
+  XlaExecutableClosure closure =
+      XlaExecutableClosureStore::Global()->Consume(key);
+
+  XlaComputationLaunchContext launch_context(
+      closure.client(), platform_info_.allocator(),
+      /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
+      /*use_multiple_streams=*/platform_info_.UseMultipleStreams());
+  launch_context.PopulateInputs(ctx, closure.compilation_result(),
+                                closure.resource_var_snapshots());
+
+  se::Stream* stream =
+      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
+  xla::ExecutableRunOptions run_options;
+  run_options.set_stream(stream);
+  run_options.set_allocator(platform_info_.allocator());
+  run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
+  run_options.set_rng_seed(GetXLARandomSeed());
+  Env* env = Env::Default();
+  auto start_time = env->NowMicros();
+
+  auto run_result =
+      closure.executable()->Run(launch_context.arguments(), run_options);
+  OP_REQUIRES(ctx, run_result.ok(), run_result.status());
+
+  auto elapsed = env->NowMicros() - start_time;
+  VLOG(2) << "Elapsed time in computation: " << elapsed << "us";
+
+  OP_REQUIRES_OK(
+      ctx, launch_context.PopulateOutputs(ctx, closure.compilation_result(),
+                                          run_result.ConsumeValueOrDie()));
+}
+
+REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp);
+
+REGISTER_KERNEL_BUILDER(Name("XlaLaunch")
+                            .Device(DEVICE_GPU)
+                            .HostMemory("constants")
+                            .HostMemory("resources"),
+                        XlaLocalLaunchOp);
+
+REGISTER_KERNEL_BUILDER(Name("_XlaCompile").Device(DEVICE_CPU), XlaCompileOp);
+REGISTER_KERNEL_BUILDER(Name("_XlaCompile")
+                            .Device(DEVICE_GPU)
+                            .HostMemory("constants")
+                            .HostMemory("resources"),
+                        XlaCompileOp);
+
+REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_CPU), XlaRunOp);
+
+REGISTER_KERNEL_BUILDER(
+    Name("_XlaRun").Device(DEVICE_GPU).HostMemory("constants"), XlaRunOp);
+
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/kernels/xla_ops.h b/tensorflow/compiler/jit/kernels/xla_ops.h
new file mode 100644
index 0000000..489d26e
--- /dev/null
+++ b/tensorflow/compiler/jit/kernels/xla_ops.h
@@ -0,0 +1,168 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_
+#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_
+
+#include "tensorflow/compiler/jit/xla_compilation_cache.h"
+#include "tensorflow/compiler/jit/xla_device.h"
+#include "tensorflow/compiler/jit/xla_launch_util.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/util/stream_executor_util.h"
+
+namespace tensorflow {
+
+// Holds some information about the platform on which an
+// XlaLaunch/_XlaCompile/_XlaRun op must run.
+class XlaPlatformInfo {
+ public:
+  XlaPlatformInfo() : device_type_("") {}
+  explicit XlaPlatformInfo(const DeviceType device_type,
+                           se::Platform::Id platform_id,
+                           const XlaDevice::Metadata* xla_device_metadata,
+                           std::unique_ptr<XlaAllocator> xla_allocator,
+                           xla::DeviceMemoryAllocator* device_allocator)
+      : device_type_(device_type),
+        platform_id_(platform_id),
+        xla_device_metadata_(xla_device_metadata),
+        xla_allocator_(std::move(xla_allocator)),
+        device_allocator_(device_allocator) {
+    CHECK((device_allocator_ != nullptr) ^ (xla_allocator_.get() != nullptr));
+  }
+
+  XlaPlatformInfo& operator=(XlaPlatformInfo&& other) = default;
+
+  bool UseMultipleStreams() const {
+    return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams();
+  }
+
+  xla::DeviceMemoryAllocator* allocator() const {
+    return device_allocator_ ? device_allocator_ : xla_allocator_.get();
+  }
+  DeviceType device_type() const { return device_type_; }
+
+  // This is equal to xla_device_metadata()->platform()->id() if
+  // xla_device_metadata() is not nullptr.
+  se::Platform::Id platform_id() const { return platform_id_; }
+
+  // This may be null if the op this XlaPlatformInfo is for was not placed on an
+  // XLA device.
+  const XlaDevice::Metadata* xla_device_metadata() const {
+    return xla_device_metadata_;
+  }
+  bool is_on_xla_device() const { return xla_device_metadata() != nullptr; }
+
+ private:
+  DeviceType device_type_;
+  se::Platform::Id platform_id_;
+
+  // xla_device_metadata_ lives in the tensorflow::DeviceBase in which the
+  // XlaLaunch/_XlaCompile/_XlaRun op is placed and thus does not die before the
+  // XlaLaunch/_XlaCompile/_XlaRun OpKernel.
+  const XlaDevice::Metadata* xla_device_metadata_;
+
+  // If the op associated with this XlaPlatformInfo is placed on an XLA device
+  // then device_allocator_ is the xla::Backend's memory allocator and
+  // xla_allocator_ is null.  If the op is placed on a regular CPU or GPU device
+  // then device_allocator_ is null and xla_allocator_ points to an appropriate
+  // XlaAllocator instance.
+  std::unique_ptr<XlaAllocator> xla_allocator_;
+  xla::DeviceMemoryAllocator* device_allocator_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo);
+};
+
+// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp.
+// The only difference is that it does not require arguments to follow
+// the "constants, then regular args, then resources" order.
+// It takes vectors of constant and resource arguments explicitly.
+// It does not have corresponding OpDef because it is never present
+// in the GraphDef.
+// Currently, it is used by eager runtime. FunctionLibraryRuntime creates
+// this kernel when asked to create a kernel for an XLA-compiled function.
+class XlaLocalLaunchBase : public OpKernel {
+ public:
+  XlaLocalLaunchBase(OpKernelConstruction* ctx,
+                     const std::vector<int>& constants,
+                     const std::vector<int>& resources,
+                     const NameAttrList& function);
+  XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete;
+  XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete;
+  ~XlaLocalLaunchBase() override = default;
+
+  void Compute(OpKernelContext* ctx) override;
+
+ protected:
+  // Indexes of compile-time constant inputs
+  std::vector<int> constants_;
+  // Indexes of resource inputs
+  std::vector<int> resources_;
+
+  NameAttrList function_;
+  XlaPlatformInfo platform_info_;
+};
+
+// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
+// which will be compiled and executed using XLA.  The XlaLocalLaunchOp is
+// responsible for handling interactions with the TensorFlow executor.
+// Once all inputs are present, and their shapes are known, the op can
+// use a 'XlaCompilationCache' to compile and execute code which is specific
+// to the shapes of input Tensors.
+// XlaLocalLaunchOp uses xla::LocalClient::Compile() and
+// xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device
+// memory.
+class XlaLocalLaunchOp : public XlaLocalLaunchBase {
+ public:
+  explicit XlaLocalLaunchOp(OpKernelConstruction* ctx);
+  ~XlaLocalLaunchOp() override;
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp);
+};
+
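+// XlaCompileOp compiles the function it was constructed with into an XLA
+// executable, stores the resulting XlaExecutableClosure (client, executable,
+// compilation result and resource-variable snapshots) in the global
+// XlaExecutableClosureStore, and emits the closure's key as a DT_STRING
+// scalar (output 0) together with a DT_BOOL "compilation successful" scalar
+// (output 1).  See XlaCompileOp::Compute in xla_ops.cc.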
+class XlaCompileOp : public OpKernel {
+ public:
+  explicit XlaCompileOp(OpKernelConstruction* ctx);
+
+  void Compute(OpKernelContext* ctx) override;
+
+ private:
+  // Indexes of compile-time constant inputs
+  std::vector<int> constants_;
+  // Indexes of resource inputs
+  std::vector<int> resources_;
+
+  NameAttrList function_;
+
+  XlaPlatformInfo platform_info_;
+};
+
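+// XlaRunOp consumes the XlaExecutableClosure produced by a preceding
+// XlaCompileOp (the closure's key arrives as the op's last input, a DT_STRING
+// scalar) and runs the compiled executable on the op's inputs.  See
+// XlaRunOp::Compute in xla_ops.cc.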
+class XlaRunOp : public OpKernel {
+ public:
+  explicit XlaRunOp(OpKernelConstruction* ctx);
+
+  void Compute(OpKernelContext* ctx) override;
+
+ private:
+  XlaPlatformInfo platform_info_;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index e6cc6e5..133d982 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -365,10 +365,13 @@
   return elementwise_ops->count(node.op()) > 0;
 }
 
+// Nodes that XLA can compile are put in `candidates`.  Nodes put in
+// `isolated_nodes` must either be unclustered or be put in trivial single-node
+// clusters.
 Status FindCompilationCandidates(
     const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env,
     const std::function<bool(const Node*, const DeviceType&)>& is_compilable_fn,
-    OrderedNodeSet* candidates) {
+    OrderedNodeSet* candidates, gtl::FlatSet<Node*>* isolated_nodes) {
   OptimizerOptions opts;
   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
       new ProcessFunctionLibraryRuntime(nullptr, env, TF_GRAPH_DEF_VERSION,
@@ -411,6 +414,8 @@
     DeviceType device_type("");
     TF_RETURN_IF_ERROR(
         DeviceToDeviceType(node->assigned_device_name(), &device_type));
+    VLOG(4) << "Device type for " << node->name() << ": "
+            << device_type.type_string();
 
     if (is_compilable_fn && !is_compilable_fn(node, device_type)) {
       // is_compilable_fn has already logged the reason if it returned false.
@@ -439,19 +444,56 @@
               << node->type_string();
       continue;
     }
-    if (compile_time_const_nodes[node->id()] &&
-        !registration->requires_compilation) {
+    if (compile_time_const_nodes[node->id()]) {
       const OpDef* op_def;
       TF_RETURN_IF_ERROR(
           graph.op_registry()->LookUpOpDef(node->type_string(), &op_def));
       if (op_def->is_stateful()) {
-        // We need to be able to constant fold the nodes in
-        // compile_time_const_nodes given constant inputs (required by XLA) and
-        // therefore can't auto-cluster stateful ops since these can never be
-        // constant folded.
-        VLOG(2) << "Rejecting " << node->name()
-                << ": must-be-constant stateful op";
-        continue;
+        // It is easiest to demonstrate the problem we're trying to solve with
+        // an example.  Say we have this graph:
+        //
+        //   shape = RandomUniformInt();
+        //   reshape = Reshape(input, shape)
+        //
+        // Both RandomUniformInt and Reshape are compilable by XLA so, absent
+        // any other reason, we will try to put both shape and reshape in the
+        // same cluster.  However, since XLA only supports statically shaped
+        // values, it will expect to be able to constant fold `shape` to get a
+        // static shape for `reshape`.  This is a problem because side-effecting
+        // ops like RandomUniformInt() cannot be constant folded.  We fix this
+        // by putting `shape` and `reshape` in different clusters, which results
+        // in us recompiling `reshape`'s cluster for every new value of `shape`,
+        // making `reshape` statically sized within each compilation.  We
+        // simplify the solution even further by disallowing operations like
+        // `shape` from being part of *any* non-trivial cluster.  They're either
+        // not compiled by XLA altogether or, if assigned to an XLA_* device
+        // with "must compile" semantics, compiled into a trivial single-op
+        // cluster.  This approach leaves some room for improvement, and we can
+        // consider implementing a more aggressive data-flow-analysis based
+        // solution in the future if needed.
+        //
+        // One ugly problem we have to contend with: certain sets of ops *have*
+        // to be in the same cluster because values flowing between them have
+        // types that can't be live-in or live-out of a cluster.  These ops are:
+        //
+        //  - TensorArray ops operating on the same TensorArray instance.
+        //  - Stack ops operating on the same Stack instance.
+        //
+        // To work around this we avoid isolating these specific ops.  Because
+        // of this concession it is unsound to auto-cluster them because then
+        // we'd create clusters we could not compile (because we can't constant
+        // fold, say, a TensorArrayRead or a StackPopV2).  But we don't
+        // auto-cluster these operations today so we're good for now.
+        const XlaResourceOpInfo* op_info =
+            GetResourceOpInfoForOp(node->type_string());
+        bool is_tensor_array_or_stack_op =
+            op_info && op_info->resource_kind() != XlaResourceKind::kVariable;
+        if (!is_tensor_array_or_stack_op) {
+          VLOG(2) << "Isolating " << node->name()
+                  << ": must-be-constant stateful op";
+          isolated_nodes->insert(node);
+          // Keep going and execute all the other checks.
+        }
       }
     }
     // We don't auto-cluster functional control flow nodes containing resource
@@ -807,11 +849,12 @@
   Graph* graph = options.graph->get();
 
   OrderedNodeSet compilation_candidates;
+  gtl::FlatSet<Node*> isolated_nodes;
   TF_RETURN_IF_ERROR(FindCompilationCandidates(
       *graph, options.flib_def,
       (options.session_options != nullptr) ? options.session_options->env
                                            : Env::Default(),
-      is_compilable_fn, &compilation_candidates));
+      is_compilable_fn, &compilation_candidates, &isolated_nodes));
 
   if (compilation_candidates.empty()) {
     VLOG(2) << "No compilable candidates";
@@ -856,6 +899,11 @@
           "Found control flow node in clustering worklist: ",
           node_from->type_string());
     }
+
+    if (isolated_nodes.count(node_from)) {
+      continue;
+    }
+
     string from_scope;
     string to_scope;
     for (int to : cycles.Successors(from)) {
@@ -873,6 +921,9 @@
           node_to->assigned_device_name()) {
         continue;
       }
+      if (isolated_nodes.count(node_to)) {
+        continue;
+      }
       // Look for an _XlaScope on both nodes.  If both nodes have a
       // scope and the scopes do not match, do not cluster along this
       // edge. This restriction is overridden if the global_jit_level is ON. If
@@ -931,6 +982,11 @@
   // Names for each cluster.
   std::unordered_map<int, string> cluster_names;
 
+  if (flags->tf_xla_clustering_debug) {
+    dump_graph::DumpGraphToFile("before_mark_for_compilation", **options.graph,
+                                options.flib_def);
+  }
+
   // Mark clusters for compilation that:
   // * are placed on a device that requires compilation (an XlaDevice),
   // * are explicitly marked for compilation (_XlaCompile=true), or
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index c59770a..4f9145b 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -894,5 +894,71 @@
   EXPECT_EQ(clusters["fn_call"], "");
 }
 
+TEST(XlaCompilationTest, RandomShapeOnXlaDevice) {
+  absl::string_view xla_gpu_device =
+      "/job:worker/replica:0/task:0/device:XLA_GPU:0";
+
+  Scope root = Scope::NewRootScope().ExitOnError();
+  Output shape_shape =
+      ops::Const(root.WithOpName("test/shape_shape"), {2}, {1});
+  Output shape =
+      ops::RandomUniformInt(root.WithOpName("test/shape_rng"), shape_shape,
+                            ops::Const(root.WithOpName("test/minval"), 1),
+                            ops::Const(root.WithOpName("test/maxval"), 20));
+  Output reshape_input =
+      ops::Placeholder(root.WithOpName("test/reshape_input"), DT_FLOAT,
+                       ops::Placeholder::Shape(TensorShape({500, 500})));
+  Output reshape =
+      ops::Reshape(root.WithOpName("test/reshape"), reshape_input, shape);
+
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+  TF_ASSERT_OK(root.ToGraph(graph.get()));
+
+  for (Node* n : graph->nodes()) {
+    if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
+      n->set_assigned_device_name(string(xla_gpu_device));
+    }
+  }
+  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
+
+  std::unordered_map<string, string> clusters = GetClusters(*graph);
+  EXPECT_NE(clusters["test/shape_rng"], "");
+  EXPECT_NE(clusters["test/reshape"], "");
+  EXPECT_NE(clusters["test/shape_rng"], clusters["test/reshape"]);
+}
+
+TEST(XlaCompilationTest, TensorArrayShapeOnXlaDevice) {
+  absl::string_view xla_gpu_device =
+      "/job:worker/replica:0/task:0/device:XLA_GPU:0";
+  Scope root = Scope::NewRootScope().ExitOnError();
+  ops::TensorArray tensor_array(root.WithOpName("test/tensor_array"), 1,
+                                DT_INT32);
+  Output zero = ops::Const(root.WithOpName("test/zero"), 0);
+  ops::TensorArrayWrite tensor_array_write(
+      root.WithOpName("test/write"), tensor_array.handle, zero,
+      ops::Const(root.WithOpName("test/forty_two"), 42.0f), tensor_array.flow);
+  Output tensor_array_read =
+      ops::TensorArrayRead(root.WithOpName("test/read"), tensor_array.handle,
+                           zero, tensor_array_write.flow_out, DT_INT32);
+  Output reshape =
+      ops::Reshape(root.WithOpName("test/reshape"),
+                   ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT),
+                   tensor_array_read);
+
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+  TF_ASSERT_OK(root.ToGraph(graph.get()));
+
+  for (Node* n : graph->nodes()) {
+    if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
+      n->set_assigned_device_name(string(xla_gpu_device));
+    }
+  }
+  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
+
+  std::unordered_map<string, string> clusters = GetClusters(*graph);
+  EXPECT_NE(clusters["test/read"], "");
+  EXPECT_EQ(clusters["test/read"], clusters["test/reshape"]);
+}
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
index 6566987..d56d0f8 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
@@ -14,18 +14,35 @@
 ==============================================================================*/
 
 #include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/public/session_options.h"
 
 namespace tensorflow {
 /*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation(
     std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
     SessionOptions* session_options) {
-  // Assign all nodes to the CPU device.
+  // Assign all unassigned nodes to the CPU device.
   static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
   for (Node* n : (*graph)->nodes()) {
-    n->set_assigned_device_name(kCpuDevice);
+    if (n->assigned_device_name().empty()) {
+      n->set_assigned_device_name(kCpuDevice);
+    }
   }
 
+  // Call AddDevices to register the XLA devices.
+  //
+  // It may be worth refactoring out XlaOpRegistry::RegisterCompilationDevice to
+  // make this more direct, but probably not worth it solely for this test.
+  std::vector<Device*> devices;
+  TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(*session_options, "", &devices));
+
+  auto delete_devices = gtl::MakeCleanup([&] {
+    for (Device* d : devices) {
+      delete d;
+    }
+  });
+
   GraphOptimizationPassOptions opt_options;
   opt_options.graph = graph;
   opt_options.session_options = session_options;
diff --git a/tensorflow/compiler/jit/node_matchers.cc b/tensorflow/compiler/jit/node_matchers.cc
new file mode 100644
index 0000000..d8ace62
--- /dev/null
+++ b/tensorflow/compiler/jit/node_matchers.cc
@@ -0,0 +1,458 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/node_matchers.h"
+
+#include <utility>
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+
+namespace tensorflow {
+namespace testing {
+namespace matchers {
+namespace {
+
+using impl::NodeMatcherProperties;
+
+string IndentAllButFirstLine(absl::string_view text) {
+  std::vector<std::string> lines = absl::StrSplit(text, '\n');
+  for (int i = 1; i < lines.size(); i++) {
+    lines[i].insert(0, "  ");
+  }
+  return absl::StrJoin(lines, "\n");
+}
+
+template <typename T>
+bool CompareTensor(const Tensor& actual, const Tensor& expected,
+                   ::testing::MatchResultListener* listener) {
+  if (actual.NumElements() != expected.NumElements()) {
+    if (listener->IsInterested()) {
+      *listener << "\nwas looking for tensor with " << expected.NumElements()
+                << " elements, found tensor with " << actual.NumElements()
+                << " elements";
+    }
+    return false;
+  }
+
+  for (int64 i = 0, e = actual.NumElements(); i < e; i++) {
+    if (actual.flat<T>()(i) != expected.flat<T>()(i)) {
+      *listener << "\nmismatch in constant tensor at index " << i
+                << " expected = " << expected.flat<T>()(i)
+                << " actual = " << actual.flat<T>()(i);
+      return false;
+    }
+  }
+
+  return true;
+}
+
+bool MatchAndExplainTensor(const Tensor& tensor, const Tensor& expected_tensor,
+                           ::testing::MatchResultListener* listener) {
+  if (tensor.dtype() != expected_tensor.dtype()) {
+    if (listener->IsInterested()) {
+      *listener << "\nexpected tensor of type "
+                << DataType_Name(expected_tensor.dtype())
+                << " but found one of type " << DataType_Name(tensor.dtype());
+    }
+    return false;
+  }
+
+  switch (tensor.dtype()) {
+    case DT_FLOAT:
+      return CompareTensor<float>(tensor, expected_tensor, listener);
+    case DT_DOUBLE:
+      return CompareTensor<double>(tensor, expected_tensor, listener);
+    case DT_INT8:
+      return CompareTensor<int8>(tensor, expected_tensor, listener);
+    case DT_INT16:
+      return CompareTensor<int16>(tensor, expected_tensor, listener);
+    case DT_INT32:
+      return CompareTensor<int32>(tensor, expected_tensor, listener);
+    case DT_INT64:
+      return CompareTensor<int64>(tensor, expected_tensor, listener);
+    case DT_UINT8:
+      return CompareTensor<uint8>(tensor, expected_tensor, listener);
+    case DT_UINT16:
+      return CompareTensor<uint16>(tensor, expected_tensor, listener);
+    case DT_UINT32:
+      return CompareTensor<uint32>(tensor, expected_tensor, listener);
+    case DT_UINT64:
+      return CompareTensor<uint64>(tensor, expected_tensor, listener);
+    default:
+      LOG(FATAL) << "Unsupported dtype "  // Crash ok: testonly.
+                 << DataType_Name(tensor.dtype());
+  }
+}
+
+using Input = std::pair<const Node*, int>;
+
+struct NodeMatcher : public ::testing::MatcherInterface<const Node*> {
+  bool MatchAndExplain(
+      const Node* node,
+      ::testing::MatchResultListener* listener) const override {
+    if (op && node->type_string() != *op) {
+      if (listener->IsInterested()) {
+        *listener << "\nexpected op " << *op << " but found "
+                  << node->type_string();
+      }
+      return false;
+    }
+
+    if (assigned_device && node->assigned_device_name() != *assigned_device) {
+      if (listener->IsInterested()) {
+        *listener << "\nexpected assigned_device " << *assigned_device
+                  << " but found \"" << node->assigned_device_name() << "\"";
+      }
+      return false;
+    }
+
+    if (name && node->name() != *name) {
+      if (listener->IsInterested()) {
+        *listener << "\nexpected name " << *name << " but found "
+                  << node->name();
+      }
+      return false;
+    }
+
+    if (constant_value) {
+      const TensorProto* proto = nullptr;
+      if (!GetNodeAttr(node->def(), "value", &proto).ok()) {
+        if (listener->IsInterested()) {
+          *listener << "\ncould not find \"value\" attribute in node";
+        }
+        return false;
+      }
+
+      Tensor tensor(proto->dtype());
+      if (!tensor.FromProto(*proto)) {
+        if (listener->IsInterested()) {
+          *listener << "\ncould not convert TensorProto in \"value\" attribute "
+                       "to Tensor";
+        }
+        return false;
+      }
+
+      if (!MatchAndExplainTensor(/*tensor=*/tensor,
+                                 /*expected_tensor=*/*constant_value,
+                                 listener)) {
+        return false;
+      }
+    }
+
+    if (input_matchers) {
+      if (input_matchers->size() != node->num_inputs()) {
+        if (listener->IsInterested()) {
+          *listener << "\nexpected " << input_matchers->size()
+                    << " inputs but node has " << node->num_inputs();
+        }
+        return false;
+      }
+
+      for (int input_idx = 0, e = input_matchers->size(); input_idx < e;
+           input_idx++) {
+        if (!MatchAndExplainInput(node, input_idx, listener)) {
+          return false;
+        }
+      }
+    }
+
+    std::vector<const Node*> control_deps;
+    for (const Edge* e : node->in_edges()) {
+      if (e->IsControlEdge()) {
+        control_deps.push_back(e->src());
+      }
+    }
+
+    ::testing::StringMatchResultListener inner_listener;
+    if (control_dep_set &&
+        !control_dep_set->MatchAndExplain(control_deps, &inner_listener)) {
+      if (listener->IsInterested()) {
+        string explanation = inner_listener.str();
+        if (!explanation.empty()) {
+          explanation = absl::StrCat(", ", explanation, ",");
+        }
+        *listener << "ctrl_deps" << explanation << " does not match expected: ";
+        control_dep_set->DescribeTo(listener->stream());
+      }
+      return false;
+    }
+    return true;
+  }
+
+  void DescribeTo(::std::ostream* os) const override {
+    std::vector<string> predicates;
+
+    if (name) {
+      predicates.push_back(absl::StrCat("name: ", *name));
+    }
+
+    if (op) {
+      predicates.push_back(absl::StrCat("op: ", *op));
+    }
+
+    if (assigned_device) {
+      predicates.push_back(absl::StrCat("assigned device: ", *assigned_device));
+    }
+
+    bool printed_something = !predicates.empty();
+
+    *os << absl::StrJoin(predicates, ", ");
+
+    if (constant_value) {
+      printed_something = true;
+      *os << "constant value: " << constant_value->DebugString();
+    }
+
+    if (input_matchers) {
+      if (!input_matchers->empty()) {
+        printed_something = true;
+        *os << " with " << (input_matchers->size() == 1 ? "only " : "")
+            << "input" << (input_matchers->size() == 1 ? "" : "s") << " ";
+      }
+
+      if (input_matchers->size() == 1) {
+        ::std::stringstream ss;
+        input_matchers->front().DescribeTo(&ss);
+        printed_something = true;
+        *os << "matching " << ss.str();
+      } else {
+        int edge_idx = 0;
+        for (const ::testing::Matcher<Input>& matcher : (*input_matchers)) {
+          *os << "\n  [" << edge_idx << "] matching (";
+          ::std::stringstream ss;
+          matcher.DescribeTo(&ss);
+          printed_something = true;
+          *os << IndentAllButFirstLine(ss.str());
+          *os << ")";
+          edge_idx++;
+        }
+      }
+    }
+
+    if (control_dep_set) {
+      printed_something = true;
+      *os << " and control deps ";
+      control_dep_set->DescribeTo(os);
+    }
+
+    if (!printed_something) {
+      *os << "is any node";
+    }
+  }
+
+  bool MatchAndExplainInput(const Node* node, int input_idx,
+                            ::testing::MatchResultListener* listener) const {
+    const Edge* edge;
+    if (!node->input_edge(input_idx, &edge).ok()) {
+      if (listener->IsInterested()) {
+        *listener << "\ncould not find incoming edge for input " << input_idx;
+      }
+      return false;
+    }
+
+    ::testing::StringMatchResultListener inner_listener;
+    Input input = {edge->src(), edge->src_output()};
+    if ((*input_matchers)[input_idx].MatchAndExplain(input, &inner_listener)) {
+      return true;
+    }
+
+    if (listener->IsInterested()) {
+      *listener << "\ninput " << input_idx << " does not match expected:\n";
+      (*input_matchers)[input_idx].DescribeTo(listener->stream());
+      string explanation = inner_listener.str();
+      if (!explanation.empty()) {
+        *listener << ", " << explanation;
+      }
+    }
+    return false;
+  }
+
+  absl::optional<string> op;
+  absl::optional<string> name;
+  absl::optional<string> assigned_device;
+  absl::optional<Tensor> constant_value;
+  absl::optional<std::vector<::testing::Matcher<Input>>> input_matchers;
+  absl::optional<::testing::Matcher<absl::Span<const Node* const>>>
+      control_dep_set;
+};
+
+// Matches the source node and source output slot of an input edge.  Today we
+// only use this with src_output=0, but we will eventually need to support
+// multi-output operations.
+class InputMatcher : public ::testing::MatcherInterface<Input> {
+ public:
+  InputMatcher(::testing::Matcher<const Node*> src_matcher, int src_output)
+      : src_matcher_(std::move(src_matcher)), src_output_(src_output) {}
+
+  bool MatchAndExplain(
+      Input input, ::testing::MatchResultListener* listener) const override {
+    ::testing::StringMatchResultListener inner_listener;
+    if (!src_matcher_.MatchAndExplain(input.first, &inner_listener)) {
+      if (listener->IsInterested()) {
+        *listener << "\nsource does not match expected ";
+        src_matcher_.DescribeTo(listener->stream());
+        string explanation = inner_listener.str();
+        if (!explanation.empty()) {
+          *listener << "\n\t" << explanation;
+        }
+      }
+      return false;
+    }
+    if (input.second != src_output_) {
+      if (listener->IsInterested()) {
+        *listener << "\nexpected output slot to be " << src_output_
+                  << " but found " << input.second;
+      }
+      return false;
+    }
+
+    return true;
+  }
+
+  void DescribeTo(::std::ostream* os) const override {
+    if (src_output_) {
+      *os << "output slot: " << src_output_ << ", source: (";
+    }
+
+    src_matcher_.DescribeTo(os);
+
+    if (src_output_) {
+      *os << ")";
+    }
+  }
+
+ private:
+  ::testing::Matcher<const Node*> src_matcher_;
+  int src_output_;
+};
+
+std::vector<::testing::Matcher<Input>> NodeMatchersToInputMatchers(
+    absl::Span<const ::testing::Matcher<const Node*>> node_matchers) {
+  std::vector<::testing::Matcher<Input>> result;
+  absl::c_transform(node_matchers, std::back_inserter(result),
+                    [](::testing::Matcher<const Node*> n) {
+                      return ::testing::MakeMatcher(new InputMatcher(n, 0));
+                    });
+  return result;
+}
+}  // namespace
+
+::testing::Matcher<const Node*> impl::NodeWith(
+    absl::Span<const NodeMatcherProperties> props) {
+  NodeMatcher* matcher = new NodeMatcher();
+  for (const NodeMatcherProperties& prop : props) {
+    if (prop.name()) {
+      DCHECK(!matcher->name);
+      matcher->name = prop.name();
+    }
+
+    if (prop.op()) {
+      DCHECK(!matcher->op);
+      matcher->op = prop.op();
+    }
+
+    if (prop.constant_value()) {
+      DCHECK(!matcher->constant_value);
+      matcher->constant_value = prop.constant_value();
+    }
+
+    if (prop.assigned_device()) {
+      DCHECK(!matcher->assigned_device);
+      matcher->assigned_device = prop.assigned_device();
+    }
+
+    if (prop.input_nodes()) {
+      DCHECK(!matcher->input_matchers);
+      matcher->input_matchers =
+          NodeMatchersToInputMatchers(*prop.input_nodes());
+    }
+
+    if (prop.control_deps()) {
+      DCHECK(!matcher->control_dep_set);
+      matcher->control_dep_set =
+          ::testing::UnorderedElementsAreArray(*prop.control_deps());
+    }
+  }
+
+  return ::testing::MakeMatcher(matcher);
+}
+
+impl::NodeMatcherProperties Name(string name) {
+  impl::NodeMatcherProperties props;
+  props.set_name(std::move(name));
+  return props;
+}
+
+// Matches a node with op `op`.
+impl::NodeMatcherProperties Op(string op) {
+  impl::NodeMatcherProperties props;
+  props.set_op(std::move(op));
+  return props;
+}
+
+// Matches a node with assigned device `assigned_device`.
+impl::NodeMatcherProperties AssignedDevice(string assigned_device) {
+  impl::NodeMatcherProperties props;
+  props.set_assigned_device(std::move(assigned_device));
+  return props;
+}
+
+impl::NodeMatcherProperties impl::Inputs(
+    absl::Span<const ::testing::Matcher<const Node*>> inputs) {
+  std::vector<::testing::Matcher<const Node*>> inputs_vector;
+  absl::c_copy(inputs, std::back_inserter(inputs_vector));
+
+  impl::NodeMatcherProperties props;
+  props.set_input_nodes(std::move(inputs_vector));
+  return props;
+}
+
+impl::NodeMatcherProperties impl::CtrlDeps(
+    absl::Span<const ::testing::Matcher<const Node*>> control_deps) {
+  std::vector<::testing::Matcher<const Node*>> control_deps_vector;
+  absl::c_copy(control_deps, std::back_inserter(control_deps_vector));
+
+  impl::NodeMatcherProperties props;
+  props.set_control_deps(std::move(control_deps_vector));
+  return props;
+}
+
+NodeMatcherProperties ConstantValue(
+    const ::tensorflow::Input::Initializer& val) {
+  TF_CHECK_OK(val.status);
+  NodeMatcherProperties props;
+  props.set_constant_value(val.tensor);
+  return props;
+}
+
+::testing::Matcher<const Node*> Const(
+    const ::tensorflow::Input::Initializer& val) {
+  return NodeWith(ConstantValue(val));
+}
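+
+// Example use (the graph and node names here are hypothetical; the variadic
+// NodeWith()/Inputs() helpers are documented at the top of node_matchers.h):
+//
+//   EXPECT_THAT(FindNodeByName(graph.get(), "test/add"),
+//               NodeWith(Op("Add"), Inputs(Const(1.0f), NodeWith(Name("x")))));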
+}  // namespace matchers
+
+Node* FindNodeByName(Graph* g, absl::string_view name) {
+  for (Node* n : g->nodes()) {
+    if (n->name() == name) {
+      return n;
+    }
+  }
+
+  return nullptr;
+}
+}  // namespace testing
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/node_matchers.h b/tensorflow/compiler/jit/node_matchers.h
new file mode 100644
index 0000000..0437a7e
--- /dev/null
+++ b/tensorflow/compiler/jit/node_matchers.h
@@ -0,0 +1,197 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Provides a set of matchers for tensorflow nodes.
+//
+// Example usage:
+//
+//  tensorflow::Node* node = ...;
+//  EXPECT_THAT(node, NodeWith(Name("name"), Op("op"),
+//                             Inputs(NodeWith(Name("input")))))
+//
+// Matchable node properties (the expressions that go inside NodeWith(...))
+// are:
+//
+//  - Name(string): matches the node name exactly.  We will probably need to
+//    have this take a string matcher in the future.
+//
+//  - Op(string): matches the op exactly.
+//
+//  - AssignedDevice(string): matches the assigned device exactly.
+//
+//  - Inputs(<ordered list>): matches the list of non-control inputs to the node
+//    exactly (i.e. does not match a suffix or a prefix).
+//
+//  - CtrlDeps(<unordered list>): matches the list of control dependencies on
+//    the node exactly, but in any order.
+//
+//  - ConstantValue(tensorflow::Input::Initializer init): matches a Const node
+//    with the constant value `init`.  Implies Op("Const").
+//
+// Node properties may not be repeated in a single NodeWith(...)  matcher.
+// E.g. NodeWith(Op("Foo"), Op("Bar")) will CHECK-fail.  Since ConstantValue
+// implies Op("Const"), a single NodeWith matcher can't have both
+// ConstantValue(...) and Op(...).
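+//
+// A minimal illustrative sketch (not part of the original comment; `add_node`
+// and the node names below are hypothetical) combining several properties:
+//
+//  EXPECT_THAT(add_node,
+//              NodeWith(Op("Add"),
+//                       Inputs(NodeWith(Name("a")), NodeWith(Name("b"))),
+//                       CtrlDeps(NodeWith(Name("init")))));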
+
+#ifndef TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_
+#define TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_
+
+#include <array>
+#include <string>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "absl/types/span.h"
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+namespace testing {
+namespace matchers {
+
+namespace impl {
+
+// -----------------------------------------------------------------------------
+// Implementation details.
+
+// Properties that we match on for a particular Node.  If a particular property
+// is nullopt then any value for it is allowed.
+class NodeMatcherProperties {
+ public:
+  using NodeSeqMatcher = std::vector<::testing::Matcher<const Node*>>;
+
+  const absl::optional<string>& name() const { return name_; }
+  const absl::optional<string>& op() const { return op_; }
+  const absl::optional<string>& assigned_device() const {
+    return assigned_device_;
+  }
+  const absl::optional<Tensor>& constant_value() const {
+    return constant_value_;
+  }
+  const absl::optional<NodeSeqMatcher>& input_nodes() const {
+    return input_nodes_;
+  }
+  const absl::optional<NodeSeqMatcher>& control_deps() const {
+    return control_deps_;
+  }
+
+  void set_name(string name) {
+    DCHECK(IsEmpty());
+    name_ = std::move(name);
+  }
+
+  void set_op(string op) {
+    DCHECK(IsEmpty());
+    op_ = std::move(op);
+  }
+
+  void set_assigned_device(string assigned_device) {
+    DCHECK(IsEmpty());
+    assigned_device_ = std::move(assigned_device);
+  }
+
+  void set_constant_value(Tensor constant_value) {
+    DCHECK(IsEmpty());
+    constant_value_ = std::move(constant_value);
+    op_ = "Const";
+  }
+
+  void set_input_nodes(NodeSeqMatcher input_nodes) {
+    DCHECK(IsEmpty());
+    input_nodes_ = std::move(input_nodes);
+  }
+
+  void set_control_deps(NodeSeqMatcher control_deps) {
+    DCHECK(IsEmpty());
+    control_deps_ = std::move(control_deps);
+  }
+
+  bool IsEmpty() const {
+    return !name().has_value() && !op().has_value() &&
+           !assigned_device().has_value() && !input_nodes().has_value() &&
+           !control_deps().has_value();
+  }
+
+ private:
+  absl::optional<string> name_;
+  absl::optional<string> op_;
+  absl::optional<string> assigned_device_;
+  absl::optional<Tensor> constant_value_;
+  absl::optional<NodeSeqMatcher> input_nodes_;
+  absl::optional<NodeSeqMatcher> control_deps_;
+};
+
+::testing::Matcher<const Node*> NodeWith(
+    absl::Span<const NodeMatcherProperties> props);
+
+impl::NodeMatcherProperties Inputs(
+    absl::Span<const ::testing::Matcher<const Node*>> inputs);
+
+impl::NodeMatcherProperties CtrlDeps(
+    absl::Span<const ::testing::Matcher<const Node*>> control_deps);
+}  // namespace impl
+
+// -----------------------------------------------------------------------------
+// Public interface.
+
+// Matches a node with name `name`.
+impl::NodeMatcherProperties Name(string name);
+
+// Matches a node with op `op`.
+impl::NodeMatcherProperties Op(string op);
+
+// Matches a node with assigned device `assigned_device`.
+impl::NodeMatcherProperties AssignedDevice(string assigned_device);
+
+// Matches a node with inputs `inputs`.
+//
+// `inputs` are ordered; `inputs`[i] must match input i.
+template <typename... Ts>
+impl::NodeMatcherProperties Inputs(Ts... inputs) {
+  return impl::Inputs({inputs...});
+}
+
+// Matches a node with control dependencies `control_deps`.
+//
+// `control_deps` are unordered and will match the control deps of a node in any
+// order.
+template <typename... Ts>
+impl::NodeMatcherProperties CtrlDeps(Ts... control_deps) {
+  return impl::CtrlDeps({control_deps...});
+}
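+
+// For example (illustrative only; the node names below are hypothetical):
+//
+//   NodeWith(Inputs(NodeWith(Name("x")), NodeWith(Name("y"))),
+//            CtrlDeps(NodeWith(Name("assert"))))
+//
+// matches a node whose data inputs are exactly x followed by y and whose only
+// control dependency is the node named "assert".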
+
+// Matches a constant node with value `val`.
+impl::NodeMatcherProperties ConstantValue(
+    const ::tensorflow::Input::Initializer& val);
+
+// The main gmock matcher.  See file comment for example usage.
+template <typename... Ts>
+::testing::Matcher<const Node*> NodeWith(Ts... args) {
+  std::array<impl::NodeMatcherProperties, sizeof...(Ts)> array = {args...};
+  return impl::NodeWith(array);
+}
+
+::testing::Matcher<const Node*> Const(
+    const ::tensorflow::Input::Initializer& val);
+}  // namespace matchers
+
+// If `g` has a node named `name` returns it, otherwise returns null.
+Node* FindNodeByName(Graph* g, absl::string_view name);
+}  // namespace testing
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_
diff --git a/tensorflow/compiler/jit/node_matchers_test.cc b/tensorflow/compiler/jit/node_matchers_test.cc
new file mode 100644
index 0000000..93a8994
--- /dev/null
+++ b/tensorflow/compiler/jit/node_matchers_test.cc
@@ -0,0 +1,179 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/node_matchers.h"
+
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/ops/array_ops.h"
+#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/math_ops.h"
+
+namespace tensorflow {
+namespace testing {
+namespace {
+
+using ::testing::_;
+
+using testing::matchers::AssignedDevice;
+using testing::matchers::ConstantValue;
+using testing::matchers::CtrlDeps;
+using testing::matchers::Inputs;
+using testing::matchers::Name;
+using testing::matchers::NodeWith;
+using testing::matchers::Op;
+
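+// Returns the gmock listener output produced when matcher `m` fails to match
+// `t`; used below to check the human-readable mismatch messages.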
+template <typename M, typename T>
+string Explain(const T& t, const M& m) {
+  ::testing::StringMatchResultListener listener;
+  EXPECT_THAT(t, ::testing::Not(m));  // For the error message.
+  EXPECT_FALSE(m.MatchAndExplain(t, &listener));
+  return listener.str();
+}
+
+TEST(NodeMatchers, CheckAgainstConstant) {
+  Scope root = Scope::NewRootScope().ExitOnError();
+  Output placeholder =
+      ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
+
+  EXPECT_THAT(placeholder.node(), NodeWith(Op("Placeholder")));
+  EXPECT_THAT(placeholder.node(), NodeWith(Name("placeholder")));
+  EXPECT_THAT(placeholder.node(),
+              NodeWith(Op("Placeholder"), Name("placeholder")));
+  EXPECT_THAT(placeholder.node(),
+              NodeWith(Name("placeholder"), Op("Placeholder")));
+  EXPECT_THAT(placeholder.node(), NodeWith(Inputs()));
+  EXPECT_THAT(placeholder.node(),
+              NodeWith(Op("Placeholder"), Name("placeholder"), Inputs()));
+
+  EXPECT_EQ(Explain(placeholder.node(), NodeWith(Op("Add"))),
+            "\nexpected op Add but found Placeholder");
+  EXPECT_EQ(Explain(placeholder.node(), NodeWith(Name("add"))),
+            "\nexpected name add but found placeholder");
+  EXPECT_EQ(Explain(placeholder.node(), NodeWith(Inputs(NodeWith()))),
+            "\nexpected 1 inputs but node has 0");
+}
+
+TEST(NodeMatchers, CheckAgainstBinary) {
+  Scope root = Scope::NewRootScope().ExitOnError();
+
+  Output placeholder_a =
+      ops::Placeholder(root.WithOpName("placeholder_a"), DT_FLOAT);
+  Output placeholder_b =
+      ops::Placeholder(root.WithOpName("placeholder_b"), DT_FLOAT);
+  Output add = ops::Add(root.WithOpName("add"), placeholder_a, placeholder_b);
+
+  EXPECT_THAT(add.node(), NodeWith(Op("Add"), Name("add"),
+                                   Inputs(NodeWith(Name("placeholder_a")),
+                                          NodeWith(Name("placeholder_b")))));
+
+  EXPECT_EQ(Explain(add.node(), NodeWith(Inputs())),
+            "\nexpected 0 inputs but node has 2");
+  EXPECT_EQ(
+      Explain(add.node(), NodeWith(Inputs(NodeWith(Name("blah")), _))),
+      "\ninput 0 does not match expected:\nname: blah, \nsource does not match "
+      "expected name: blah\n\t\nexpected name blah but found placeholder_a");
+  EXPECT_EQ(
+      Explain(add.node(), NodeWith(Inputs(_, NodeWith(Name("blah"))))),
+      "\ninput 1 does not match expected:\nname: blah, \nsource does not match "
+      "expected name: blah\n\t\nexpected name blah but found placeholder_b");
+}
+
+TEST(NodeMatchers, CheckControlDependence) {
+  Scope root = Scope::NewRootScope().ExitOnError();
+
+  Output placeholder_a =
+      ops::Placeholder(root.WithOpName("placeholder_a"), DT_FLOAT);
+  Output placeholder_b =
+      ops::Placeholder(root.WithOpName("placeholder_b"), DT_FLOAT);
+  Output placeholder_c =
+      ops::Placeholder(root.WithOpName("placeholder_c"), DT_FLOAT);
+  Output placeholder_d =
+      ops::Placeholder(root.WithOpName("placeholder_d"), DT_FLOAT);
+
+  root.graph()->AddControlEdge(placeholder_a.node(), placeholder_c.node());
+  root.graph()->AddControlEdge(placeholder_b.node(), placeholder_c.node());
+
+  EXPECT_THAT(placeholder_c.node(),
+              NodeWith(Name("placeholder_c"),
+                       CtrlDeps(NodeWith(Name("placeholder_a")),
+                                NodeWith(Name("placeholder_b")))));
+  EXPECT_THAT(placeholder_d.node(),
+              NodeWith(Name("placeholder_d"), CtrlDeps()));
+
+  EXPECT_EQ(
+      Explain(placeholder_c.node(), NodeWith(CtrlDeps())),
+      "ctrl_deps, which has 2 elements, does not match expected: is empty");
+  EXPECT_EQ(Explain(placeholder_d.node(), NodeWith(CtrlDeps(NodeWith()))),
+            "ctrl_deps does not match expected: has 1 element and that element "
+            "is any node");
+}
+
+TEST(NodeMatchers, ConstValue) {
+  Scope root = Scope::NewRootScope().ExitOnError();
+  Output placeholder =
+      ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
+  Output const_0d = ops::Const(root.WithOpName("const_0d"), 42);
+
+  Output const_2d = ops::Const(root.WithOpName("const_2d"), {{1, 2}, {4, 3}});
+
+  EXPECT_THAT(const_0d.node(), NodeWith(ConstantValue(42)));
+  EXPECT_THAT(const_0d.node(), NodeWith(ConstantValue(42), Name("const_0d")));
+
+  EXPECT_THAT(const_2d.node(), NodeWith(ConstantValue({{1, 2}, {4, 3}})));
+
+  EXPECT_EQ(Explain(placeholder.node(), NodeWith(ConstantValue(42))),
+            "\nexpected op Const but found Placeholder");
+  EXPECT_EQ(
+      Explain(const_0d.node(), NodeWith(ConstantValue(43))),
+      "\nmismatch in constant tensor at index 0 expected = 43 actual = 42");
+  EXPECT_EQ(
+      Explain(const_0d.node(), NodeWith(ConstantValue({{1, 2}, {4, 3}}))),
+      "\nwas looking for tensor with 4 elements, found tensor with 1 elements");
+  EXPECT_EQ(
+      Explain(const_2d.node(), NodeWith(ConstantValue(42))),
+      "\nwas looking for tensor with 1 elements, found tensor with 4 elements");
+}
+
+TEST(NodeMatchers, AssignedDevice) {
+  Scope root = Scope::NewRootScope().ExitOnError();
+
+  Output placeholder_a =
+      ops::Placeholder(root.WithOpName("placeholder_a"), DT_FLOAT);
+  Output placeholder_b =
+      ops::Placeholder(root.WithOpName("placeholder_b"), DT_FLOAT);
+
+  Output assigned_add =
+      ops::Add(root.WithOpName("assigned_add"), placeholder_a, placeholder_b);
+  assigned_add.node()->set_assigned_device_name(
+      "/job:localhost/replica:0/task:0/device:CPU:0");
+
+  Output unassigned_add =
+      ops::Add(root.WithOpName("unassigned_add"), placeholder_a, placeholder_b);
+
+  EXPECT_THAT(
+      assigned_add.node(),
+      NodeWith(AssignedDevice("/job:localhost/replica:0/task:0/device:CPU:0")));
+  EXPECT_THAT(unassigned_add.node(), NodeWith(AssignedDevice("")));
+
+  EXPECT_EQ(Explain(unassigned_add.node(),
+                    NodeWith(AssignedDevice(
+                        "/job:localhost/replica:0/task:0/device:CPU:0"))),
+            "\nexpected assigned_device "
+            "/job:localhost/replica:0/task:0/device:CPU:0 but found \"\"");
+}
+
+}  // namespace
+}  // namespace testing
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD
index 13804c6..f722245 100644
--- a/tensorflow/compiler/jit/ops/BUILD
+++ b/tensorflow/compiler/jit/ops/BUILD
@@ -4,9 +4,17 @@
     default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
 )
 
+load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
+
 cc_library(
     name = "xla_ops",
     srcs = ["xla_ops.cc"],
     deps = ["//tensorflow/core:framework"],
     alwayslink = 1,
 )
+
+tf_gen_op_wrapper_py(
+    name = "xla_ops_wrapper_py",
+    out = "xla_ops.py",
+    deps = ["//tensorflow/compiler/jit/ops:xla_ops"],
+)
diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc
index f2473d9..6b4cdaa 100644
--- a/tensorflow/compiler/jit/ops/xla_ops.cc
+++ b/tensorflow/compiler/jit/ops/xla_ops.cc
@@ -13,10 +13,14 @@
 limitations under the License.
 ==============================================================================*/
 
+#include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
 
 namespace tensorflow {
 
+using shape_inference::InferenceContext;
+
 REGISTER_OP("XlaLaunch")
     .Input("constants: Tconstants")
     .Attr("Tconstants: list(type) >= 0")
@@ -32,4 +36,62 @@
     .SetIsStateful()
     .Doc("XLA Launch Op. For use by the XLA JIT only.");
 
+REGISTER_OP("XlaClusterOutput")
+    .Input("input: T")
+    // Note: when replication is supported, this op will have N outputs.
+    .Output("outputs: T")
+    .Attr("T: type")
+    .SetShapeFn([](InferenceContext* c) {
+      for (int i = 0; i < c->num_outputs(); ++i) {
+        c->set_output(i, c->input(0));
+      }
+      return Status::OK();
+    })
+    .Doc(
+        "Operator that connects the output of an XLA computation to other "
+        "consumer graph nodes.");
+
+REGISTER_OP("_XlaCompile")
+    .Input("constants: Tconstants")
+    .Attr("Tconstants: list(type) >= 0")
+    .Input("args: Targs")
+    .Attr("Targs: list(type) >= 0")
+    .Input("resources: Nresources * resource")
+    .Attr("Nresources: int >= 0")
+    .Output("key: string")
+    .Output("compilation_successful: bool")
+    .Attr("function: func")
+    // The compilation cache is stateful.
+    .SetIsStateful()
+    .Doc(R"(XLA Compile Op. For use by the XLA JIT only.
+
+Compiles a TensorFlow function into an XLA LocalExecutable and returns a key
+that _XlaRun can use to look up the LocalExecutable and execute it.
+
+key: A key that can be used to look up the local executable compiled by the
+   node and associated metadata.
+
+compilation_successful: True iff the compilation was successful.  Always true
+for now.
+)");
+
+REGISTER_OP("_XlaRun")
+    // TODO(sanjoy): We don't need constants and Tconstants and they should be
+    // removed.
+    .Input("constants: Tconstants")
+    .Attr("Tconstants: list(type) >= 0")
+    .Input("args: Targs")
+    .Attr("Targs: list(type) >= 0")
+    .Output("results: Tresults")
+    .Attr("Tresults: list(type) >= 0")
+    .Input("key: string")
+    // XLA random-number generation ops are stateful.
+    // TODO(phawkins): create stateful and non-stateful variants of _XlaRun.
+    .SetIsStateful()
+    .Doc(R"(XLA Run Op. For use by the XLA JIT only.
+
+Executes a TensorFlow function previously compiled into a LocalExecutable by an
+_XlaCompile op.
+)");
+
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc
index 35872da..0feb73a 100644
--- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc
+++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc
@@ -60,9 +60,9 @@
   void Compute(OpKernelContext* ctx) override { CHECK(false); }
 };
 
-class FakeResourceVarUpdateOp : public OpKernel {
+class FakeResourceUpdateOp : public OpKernel {
  public:
-  explicit FakeResourceVarUpdateOp(OpKernelConstruction* context)
+  explicit FakeResourceUpdateOp(OpKernelConstruction* context)
       : OpKernel(context) {}
 
   void Compute(OpKernelContext* ctx) override { CHECK(false); }
@@ -74,10 +74,9 @@
                             .HostMemory("host_out"),
                         FakeBinaryOp);
 
-REGISTER_KERNEL_BUILDER(Name("FakeResourceVarUpdate")
-                            .Device(DEVICE_CPU)
-                            .HostMemory("something_else"),
-                        FakeResourceVarUpdateOp);
+REGISTER_KERNEL_BUILDER(
+    Name("FakeResourceUpdate").Device(DEVICE_CPU).HostMemory("something_else"),
+    FakeResourceUpdateOp);
 
 Status PartiallyDecluster(std::unique_ptr<Graph>* graph) {
   FixupSourceAndSinkEdges(graph->get());
diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc
index 7e159e3..003c1d8 100644
--- a/tensorflow/compiler/jit/xla_cpu_device.cc
+++ b/tensorflow/compiler/jit/xla_cpu_device.cc
@@ -16,7 +16,7 @@
 // Registers the XLA_CPU device, which is an XlaDevice instantiation that runs
 // operators using XLA via the XLA "Host" (CPU) backend.
 
-#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
+#include "tensorflow/compiler/jit/kernels/xla_ops.h"
 #include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h"
 #include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
 #include "tensorflow/compiler/jit/xla_device.h"
@@ -65,10 +65,14 @@
 
 // Kernel registrations
 
-constexpr std::array<DataType, 7> kAllXlaCpuTypes = {
-    {DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
+constexpr std::array<DataType, 12> kAllXlaCpuTypes = {
+    {DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64,
+     DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
 
 REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes);
+REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_CPU, XlaCompileOp, kAllXlaCpuTypes);
+REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_CPU, XlaRunOp, kAllXlaCpuTypes);
+
 REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes);
 
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index 51797de..32fce2b 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -434,6 +434,16 @@
   return status;
 }
 
+void XlaDevice::SetRequiresSyncOnCompletion(bool sync_on_completion) {
+  mutex_lock lock(mu_);
+  sync_on_completion_ = sync_on_completion;
+}
+
+bool XlaDevice::RequiresSyncOnCompletion() const {
+  mutex_lock lock(mu_);
+  return sync_on_completion_;
+}
+
 XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
                                                    const char* jit_device) {
   // Any op assigned to the device that isn't rewritten by the graph rewriter
diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h
index 92891ff..0f06b3f 100644
--- a/tensorflow/compiler/jit/xla_device.h
+++ b/tensorflow/compiler/jit/xla_device.h
@@ -151,6 +151,12 @@
   // information for GPU and TPU devices.
   Status UseGpuDeviceInfo() LOCKS_EXCLUDED(mu_);
 
+  // Instructs this XlaDevice to return 'sync_on_completion' for
+  // RequiresSyncOnCompletion().
+  void SetRequiresSyncOnCompletion(bool sync_on_completion) LOCKS_EXCLUDED(mu_);
+
+  bool RequiresSyncOnCompletion() const override LOCKS_EXCLUDED(mu_);
+
  private:
   xla::LocalClient* client() const;
   Allocator* GetAllocatorLocked(AllocatorAttributes attr)
@@ -165,7 +171,7 @@
   static Status GetMetadataFromDevice(DeviceBase* device,
                                       const XlaDevice::Metadata** metadata);
 
-  mutex mu_;
+  mutable mutex mu_;
   // The metadata of this XlaDevice.
   const Metadata xla_metadata_;
   // Which hardware device in the client's platform this XlaDevice controls.
@@ -207,6 +213,10 @@
 
   // Thread pool used for running closures
   std::unique_ptr<thread::ThreadPool> thread_pool_;
+
+  // True if the device requires XlaDevice::Sync to be called on completion
+  // regardless of status.
+  bool sync_on_completion_ GUARDED_BY(mu_) = false;
 };
 
 // Builds OpKernel registrations on 'device' for the JIT operators
diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h
index 49c8582..6392439 100644
--- a/tensorflow/compiler/jit/xla_device_ops.h
+++ b/tensorflow/compiler/jit/xla_device_ops.h
@@ -65,6 +65,17 @@
                               .HostMemory("resources"),   \
                           KERNEL);
 
+#define REGISTER_XLA_COMPILE_KERNEL(DEVICE, KERNEL, TYPES) \
+  REGISTER_KERNEL_BUILDER(Name("_XlaCompile")              \
+                              .Device(DEVICE)              \
+                              .HostMemory("constants")     \
+                              .HostMemory("resources"),    \
+                          KERNEL);
+
+#define REGISTER_XLA_RUN_KERNEL(DEVICE, KERNEL, TYPES) \
+  REGISTER_KERNEL_BUILDER(                             \
+      Name("_XlaRun").Device(DEVICE).HostMemory("constants"), KERNEL);
+
 #define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES)                             \
   REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE), SendOp);               \
   REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE), RecvOp);               \
diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc
index ef4466f..6097955 100644
--- a/tensorflow/compiler/jit/xla_gpu_device.cc
+++ b/tensorflow/compiler/jit/xla_gpu_device.cc
@@ -16,7 +16,7 @@
 // Registers the XLA_GPU device, which is an XlaDevice instantiation that runs
 // operators using XLA via the XLA "CUDA" (GPU) backend.
 
-#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
+#include "tensorflow/compiler/jit/kernels/xla_ops.h"
 #include "tensorflow/compiler/jit/xla_device.h"
 #include "tensorflow/compiler/jit/xla_device_ops.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -74,11 +74,14 @@
 
 // Kernel registrations
 
-constexpr std::array<DataType, 8> kAllXlaGpuTypes = {
-    {DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL,
-     DT_BFLOAT16}};
+constexpr std::array<DataType, 13> kAllXlaGpuTypes = {
+    {DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64,
+     DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL, DT_BFLOAT16}};
 
 REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes);
+REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_GPU, XlaCompileOp, kAllXlaGpuTypes);
+REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_GPU, XlaRunOp, kAllXlaGpuTypes);
+
 REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes);
 
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc
index 4574559..19e681a 100644
--- a/tensorflow/compiler/jit/xla_interpreter_device.cc
+++ b/tensorflow/compiler/jit/xla_interpreter_device.cc
@@ -15,7 +15,7 @@
 
 // Registers the XLA_INTERPRETER device which exposes the XLA Interpreter.
 
-#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
+#include "tensorflow/compiler/jit/kernels/xla_ops.h"
 #include "tensorflow/compiler/jit/xla_device.h"
 #include "tensorflow/compiler/jit/xla_device_ops.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -72,6 +72,10 @@
 
 REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_INTERPRETER, XlaLocalLaunchOp,
                            kExecAllTypes);
+REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_INTERPRETER, XlaCompileOp,
+                            kExecAllTypes);
+REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_INTERPRETER, XlaRunOp, kExecAllTypes);
+
 REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_INTERPRETER, kExecAllTypes);
 REGISTER_XLA_BACKEND(DEVICE_INTERPRETER_XLA_JIT, kExecAllTypes, OpFilter);
 
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index affeab4..07a93e9 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -42,7 +42,7 @@
 }  // anonymous namespace
 
 std::map<int, OptionalTensor> SnapshotResourceVariables(
-    OpKernelContext* ctx, const std::vector<int>& variables) {
+    OpKernelContext* ctx, absl::Span<const int> variables) {
   std::map<int, OptionalTensor> snapshot;
   for (int i : variables) {
     Var* variable = nullptr;
@@ -275,6 +275,8 @@
       VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
               << DataTypeString(type);
       if (type == DT_RESOURCE) {
+        TF_RET_CHECK(kernel->outputs[i].input_index >= 0)
+            << "Invalid input for output " << i;
         ctx->set_output(i, ctx->input(kernel->outputs[i].input_index));
       } else {
         se::DeviceMemoryBase buffer = output.buffer({output_num});
diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h
index 7ac275f..fa7a5e5 100644
--- a/tensorflow/compiler/jit/xla_launch_util.h
+++ b/tensorflow/compiler/jit/xla_launch_util.h
@@ -29,6 +29,7 @@
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/kernels/variable_ops.h"
 #include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
 
 namespace tensorflow {
 class XlaAllocator;
@@ -43,7 +44,7 @@
 // resource variable is not initialized, the corresponding OptionalTensor
 // will have its `present` field set to false.
 std::map<int, OptionalTensor> SnapshotResourceVariables(
-    OpKernelContext* ctx, const std::vector<int>& variables);
+    OpKernelContext* ctx, absl::Span<const int> variables);
 
 // Adapter class that wraps a Tensorflow allocator as an XLA allocator.
 // Assumes that the Tensorflow allocator permits asynchronous deallocation:
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 2176eae..3cf74fa 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -277,9 +277,10 @@
     ],
 )
 
+# This test is marked large because the CPU variant occasionally takes a long time in testConcatLargeNumberOfTensors.
 tf_xla_py_test(
     name = "concat_ops_test",
-    size = "medium",
+    size = "large",
     srcs = ["concat_ops_test.py"],
     deps = [
         ":xla_test",
@@ -977,7 +978,7 @@
     name = "gather_test",
     size = "medium",
     srcs = ["gather_test.py"],
-    tags = ["noasan"],  # times out, http://b/78599043
+    tags = ["optonly"],
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
@@ -1197,6 +1198,19 @@
 )
 
 tf_xla_py_test(
+    name = "quantized_ops_test",
+    size = "small",
+    srcs = ["quantized_ops_test.py"],
+    deps = [
+        ":xla_test",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
+tf_xla_py_test(
     name = "xla_ops_test",
     size = "medium",
     srcs = ["xla_ops_test.py"],
diff --git a/tensorflow/compiler/tests/argminmax_test.py b/tensorflow/compiler/tests/argminmax_test.py
index 4155342..68f52e7 100644
--- a/tensorflow/compiler/tests/argminmax_test.py
+++ b/tensorflow/compiler/tests/argminmax_test.py
@@ -50,12 +50,12 @@
 
   def testArgMinMax(self):
     # Complex numbers do not support argmin/argmax.
-    minmax_types = set(self.numeric_types) - set(self.complex_types)
+    minmax_types = self.all_types & {np.int32, np.int64}
     for dtype in minmax_types:
       # output_type is a numpy data type that is used to specify the desired
       # output type of the op as well as to convert the Python number to the
       # array scalar of the type.
-      for output_type in self.int_types:
+      for output_type in minmax_types:
         self._assertOpOutputMatchesExpected(
             math_ops.argmax,
             axis=0,
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index 17280e4..900e84a 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -210,7 +210,7 @@
             equality_test=self.ListsAreClose)
 
   def testIntOps(self):
-    for dtype in self.int_types:
+    for dtype in self.signed_int_types:
       self._testBinary(
           gen_math_ops.truncate_div,
           np.array([3, 3, -1, -9, -8], dtype=dtype),
@@ -287,7 +287,8 @@
           dtype(7),
           expected=np.array([[-6], [-5]], dtype=dtype))
 
-      if dtype not in self.complex_types:  # min/max not supported for complex
+      # min/max not supported for complex
+      if dtype not in self.complex_types | {np.uint8, np.int8}:
         self._testBinary(
             math_ops.maximum,
             np.array([1, 2], dtype=dtype),
@@ -337,7 +338,7 @@
           expected=np.array([[70], [14]], dtype=dtype))
 
       # Complex support for squared_difference is incidental, see b/68205550
-      if dtype not in self.complex_types:
+      if dtype not in self.complex_types | {np.uint8, np.int8}:
         self._testBinary(
             math_ops.squared_difference,
             np.array([1, 2], dtype=dtype),
@@ -567,7 +568,7 @@
           expected=np.array([1, -2, -1, -5, 2], dtype=dtype))
 
   def testIntDivision(self):
-    for dtype in self.int_types:
+    for dtype in self.signed_int_types:
       self._testDivision(dtype)
 
   def testFloatDivision(self):
@@ -588,7 +589,7 @@
         expected=np.array([1, 1, -1, 0], dtype=dtype))
 
   def testIntRemainder(self):
-    for dtype in self.int_types:
+    for dtype in self.signed_int_types - {np.int8}:
       self._testRemainder(dtype)
 
   def testFloatRemainder(self):
diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl
index 7b114d4..1d3979b 100644
--- a/tensorflow/compiler/tests/build_defs.bzl
+++ b/tensorflow/compiler/tests/build_defs.bzl
@@ -2,90 +2,103 @@
 
 load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured")
 load("//tensorflow/compiler/tests:plugin.bzl", "plugins")
+load(
+    "//tensorflow/core:platform/default/build_config_root.bzl",
+    "tf_cuda_tests_tags",
+)
 
 def all_backends():
-  b = ["cpu"] + plugins.keys()
-  if cuda_is_configured():
-    return b + ["gpu"]
-  else:
-    return b
-
-def tf_xla_py_test(name, srcs=[], deps=[], tags=[], data=[], main=None,
-                   disabled_backends=None, **kwargs):
-  """Generates py_test targets, one per XLA backend.
-
-  This rule generates py_test() targets named name_backend, for each backend
-  in all_backends(). The rule also generates a test suite with named `name` that
-  tests all backends for the test.
-
-  For example, the following rule generates test cases foo_test_cpu,
-  foo_test_gpu, and a test suite name foo_test that tests both.
-  tf_xla_py_test(
-      name="foo_test",
-      srcs="foo_test.py",
-      deps=[...],
-  )
-
-  Args:
-    name: Name of the target.
-    srcs: Sources for the target.
-    deps: Dependencies of the target.
-    tags: Tags to apply to the generated targets.
-    data: Data dependencies of the target.
-    main: Same as py_test's main attribute.
-    disabled_backends: A list of backends that should not be tested. Supported
-      values include "cpu" and "gpu". If not specified, defaults to None.
-    **kwargs: keyword arguments passed onto the generated py_test() rules.
-  """
-  if disabled_backends == None:
-    disabled_backends = []
-
-  enabled_backends = [b for b in all_backends() if b not in disabled_backends]
-  test_names = []
-  for backend in enabled_backends:
-    test_name = "{}_{}".format(name, backend)
-    backend_tags = ["tf_xla_{}".format(backend)]
-    backend_args = []
-    backend_deps = []
-    backend_data = []
-    if backend == "cpu":
-      backend_args += [
-          "--test_device=XLA_CPU",
-          "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64"
-      ]
-    elif backend == "gpu":
-      backend_args += [
-          "--test_device=XLA_GPU",
-          "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16"
-      ]
-      backend_tags += ["requires-gpu-sm35"]
-    elif backend in plugins:
-      backend_args += ["--test_device=" + plugins[backend]["device"],
-                       "--types=" + plugins[backend]["types"]]
-      backend_tags += plugins[backend]["tags"]
-      backend_args += plugins[backend]["args"]
-      backend_deps += plugins[backend]["deps"]
-      backend_data += plugins[backend]["data"]
+    b = ["cpu"] + plugins.keys()
+    if cuda_is_configured():
+        return b + ["gpu"]
     else:
-      fail("Unknown backend {}".format(backend))
+        return b
 
-    native.py_test(
-        name=test_name,
-        srcs=srcs,
-        srcs_version="PY2AND3",
-        args=backend_args,
-        main="{}.py".format(name) if main == None else main,
-        data=data + backend_data,
-        deps=deps + backend_deps,
-        tags=tags + backend_tags,
-        **kwargs
+def tf_xla_py_test(
+        name,
+        srcs = [],
+        deps = [],
+        tags = [],
+        data = [],
+        main = None,
+        disabled_backends = None,
+        **kwargs):
+    """Generates py_test targets, one per XLA backend.
+
+    This rule generates py_test() targets named name_backend, for each backend
+    in all_backends(). The rule also generates a test suite named `name` that
+    tests all backends for the test.
+
+    For example, the following rule generates test cases foo_test_cpu,
+    foo_test_gpu, and a test suite named foo_test that tests both.
+    tf_xla_py_test(
+        name="foo_test",
+        srcs="foo_test.py",
+        deps=[...],
     )
-    test_names.append(test_name)
-  native.test_suite(name=name, tests=test_names)
 
-def generate_backend_suites(backends=[]):
-  """Generates per-backend test_suites that run all tests for a backend."""
-  if not backends:
-    backends = all_backends()
-  for backend in backends:
-    native.test_suite(name="%s_tests" % backend, tags=["tf_xla_%s" % backend])
+    Args:
+      name: Name of the target.
+      srcs: Sources for the target.
+      deps: Dependencies of the target.
+      tags: Tags to apply to the generated targets.
+      data: Data dependencies of the target.
+      main: Same as py_test's main attribute.
+      disabled_backends: A list of backends that should not be tested. Supported
+        values include "cpu" and "gpu". If not specified, defaults to None.
+      **kwargs: keyword arguments passed onto the generated py_test() rules.
+    """
+    if disabled_backends == None:
+        disabled_backends = []
+
+    enabled_backends = [b for b in all_backends() if b not in disabled_backends]
+    test_names = []
+    for backend in enabled_backends:
+        test_name = "{}_{}".format(name, backend)
+        backend_tags = ["tf_xla_{}".format(backend)]
+        backend_args = []
+        backend_deps = []
+        backend_data = []
+        if backend == "cpu":
+            backend_args += [
+                "--test_device=XLA_CPU",
+                "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64",
+            ]
+        elif backend == "gpu":
+            backend_args += [
+                "--test_device=XLA_GPU",
+                "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16",
+            ]
+            backend_tags += tf_cuda_tests_tags()
+        elif backend in plugins:
+            backend_args += [
+                "--test_device=" + plugins[backend]["device"],
+                "--types=" + plugins[backend]["types"],
+            ]
+            backend_tags += plugins[backend]["tags"]
+            backend_args += plugins[backend]["args"]
+            backend_deps += plugins[backend]["deps"]
+            backend_data += plugins[backend]["data"]
+        else:
+            fail("Unknown backend {}".format(backend))
+
+        native.py_test(
+            name = test_name,
+            srcs = srcs,
+            srcs_version = "PY2AND3",
+            args = backend_args,
+            main = "{}.py".format(name) if main == None else main,
+            data = data + backend_data,
+            deps = deps + backend_deps,
+            tags = tags + backend_tags,
+            **kwargs
+        )
+        test_names.append(test_name)
+    native.test_suite(name = name, tests = test_names)
+
+def generate_backend_suites(backends = []):
+    """Generates per-backend test_suites that run all tests for a backend."""
+    if not backends:
+        backends = all_backends()
+    for backend in backends:
+        native.test_suite(name = "%s_tests" % backend, tags = ["tf_xla_%s" % backend])
diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py
index 37e5318..2d225ad 100644
--- a/tensorflow/compiler/tests/concat_ops_test.py
+++ b/tensorflow/compiler/tests/concat_ops_test.py
@@ -291,6 +291,41 @@
             ValueError, r"Can't concatenate scalars \(use tf\.stack instead\)"):
           array_ops.concat([scalar, scalar, scalar], dim)
 
+  # The purpose of this is to ensure that XLA on GPU will not run out of memory
+  # with too many arguments.
+  def testConcatLargeNumberOfTensors(self):
+    with self.cached_session():
+      with self.test_scope():
+        for concat_dim in range(2):
+          params = {}
+          p = []
+          shape = np.array([7, 13])
+          num_tensors = 1001
+          for i in np.arange(num_tensors):
+            input_shape = shape
+            placeholder = array_ops.placeholder(
+                dtypes.float32, shape=input_shape)
+            p.append(placeholder)
+            params[placeholder] = np.random.rand(*input_shape).astype(
+                np.float32)
+
+          concat_inputs = p
+          c = array_ops.concat(concat_inputs, concat_dim)
+          result = c.eval(feed_dict=params)
+
+          self.assertEqual(result.shape, c.get_shape())
+          cur_offset = 0
+
+          for i in np.arange(num_tensors):
+            # The index into the result is the ':' along all dimensions
+            # except the concat_dim. slice(0, size) is used for ':', and
+            # a list of slices is used to index into result.
+            index = [slice(0, params[p[i]].shape[j]) for j in np.arange(2)]
+            index[concat_dim] = slice(
+                cur_offset, cur_offset + params[p[i]].shape[concat_dim])
+            cur_offset += params[p[i]].shape[concat_dim]
+            self.assertAllEqual(result[index], params[p[i]])
+
 
 class ConcatOffsetTest(xla_test.XLATestCase):
 
diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py
index 0af74c2..9390870 100644
--- a/tensorflow/compiler/tests/dense_layer_test.py
+++ b/tensorflow/compiler/tests/dense_layer_test.py
@@ -45,17 +45,21 @@
   return any([substr in x for x in labels])
 
 
-def XlaLaunchOpCount(labels):
-  """Count how many XlaLaunch labels are present."""
-  return sum("XlaLaunch(" in x for x in labels)
-
-
 class DenseLayerTest(test.TestCase):
 
+  def countXlaOps(self, labels):
+    """Count how many XlaCompile/XlaRun labels are present."""
+    xla_compile_count = sum("XlaCompile(" in x for x in labels)
+    xla_run_count = sum("XlaRun(" in x for x in labels)
+    self.assertEqual(xla_compile_count, xla_run_count)
+    return xla_run_count
+
+
   def testDenseLayerAutoJit(self):
     """Tests dense layer compilation in auto-jit mode.
 
-    Dense layer should be compiled into a single XlaLaunch op in auto-jit mode.
+    Dense layer should be compiled into a single XlaCompile/XlaRun op pair in
+    auto-jit mode.
     """
 
     os.environ["TF_XLA_FLAGS"] = (
@@ -77,14 +81,14 @@
               trace_level=config_pb2.RunOptions.FULL_TRACE))
 
     labels = GetRunMetadataLabels(run_metadata)
-    self.assertEqual(1, XlaLaunchOpCount(labels))
+    self.assertEqual(1, self.countXlaOps(labels))
     self.assertFalse(InLabels(labels, "MatMult"))
 
   def testDenseLayerJitScopeDefinedShape(self):
     """Tests that the dense layer node is properly compiled in jit scope.
 
     Dense layer with static shape input tensor should be compiled into a single
-    XlaLaunch op by XLA.
+    XlaCompile/XlaRun op pair by XLA.
     """
 
     with self.cached_session() as sess:
@@ -101,7 +105,7 @@
               trace_level=config_pb2.RunOptions.FULL_TRACE))
 
     labels = GetRunMetadataLabels(run_metadata)
-    self.assertEqual(1, XlaLaunchOpCount(labels))
+    self.assertEqual(1, self.countXlaOps(labels))
     # No need to check whether ListDiff is compiled or not because ListDiff op
     # is not used when input tensor shape is fully defined.
 
@@ -111,7 +115,8 @@
     Dense layer uses shape op to get shape of input tensor if its shape is not
     fully defined. XLA does not cluster shape op with other operators. But in
     experimental_jit_scope, XLA is forced to compile shape op into its own
-    cluster, causing dense layer to be split into TWO XlaLaunch ops.
+    cluster, causing dense layer to be split into TWO XlaCompile/XlaRun op
+    pairs.
     """
 
     with self.cached_session() as sess:
@@ -128,7 +133,7 @@
               trace_level=config_pb2.RunOptions.FULL_TRACE))
 
     labels = GetRunMetadataLabels(run_metadata)
-    self.assertEqual(2, XlaLaunchOpCount(labels))
+    self.assertEqual(2, self.countXlaOps(labels))
     self.assertFalse(InLabels(labels, "MatMult"))
 
 
diff --git a/tensorflow/compiler/tests/gather_test.py b/tensorflow/compiler/tests/gather_test.py
index 089d95d..a38e1ed 100644
--- a/tensorflow/compiler/tests/gather_test.py
+++ b/tensorflow/compiler/tests/gather_test.py
@@ -51,7 +51,7 @@
           indices_tf = constant_op.constant(indices)
           gather_t = array_ops.gather(params, indices_tf)
           gather_val = session.run(gather_t, feed_dict={params: params_np})
-          np_val = params_np[indices]
+          np_val = constant_op.constant(params_np[indices])
           self.assertAllEqual(np_val, gather_val)
 
   def testScalar2D(self):
@@ -65,7 +65,8 @@
           indices = constant_op.constant(2)
           gather_t = array_ops.gather(params, indices, axis=axis)
           gather_val = session.run(gather_t, feed_dict={params: params_np})
-          expected = np.take(params_np, 2, axis=axis)
+          expected = constant_op.constant(
+              np.take(params_np, 2, axis=axis), dtype)
           self.assertAllEqual(expected, gather_val)
 
   def testSimpleTwoD32(self):
@@ -80,7 +81,8 @@
           indices = constant_op.constant([0, 1, 0, 2])
           gather_t = array_ops.gather(params, indices, axis=axis)
           gather_val = session.run(gather_t, feed_dict={params: params_np})
-          expected = np.take(params_np, [0, 1, 0, 2], axis=axis)
+          expected = constant_op.constant(
+              np.take(params_np, [0, 1, 0, 2], axis=axis), dtype)
           self.assertAllEqual(expected, gather_val)
 
   def testSimpleTwoD32_Int64Indices(self):
@@ -103,7 +105,8 @@
                   params: params_np,
                   indices: indices_np
               })
-          expected = np.take(params_np, [0, 1, 0, 2], axis=axis)
+          expected = constant_op.constant(
+              np.take(params_np, [0, 1, 0, 2], axis=axis), dtype)
           self.assertAllEqual(expected, gather_val)
 
   def testHigherRank(self):
@@ -119,7 +122,8 @@
             tf_indices = constant_op.constant(indices, dtype=dtypes.int32)
             gather = array_ops.gather(tf_params, tf_indices, axis=axis)
             gather_value = sess.run(gather, feed_dict={tf_params: params})
-            gather_np = np.take(params, indices, axis=axis)
+            gather_np = constant_op.constant(
+                np.take(params, indices, axis=axis), dtype)
             self.assertAllEqual(gather_np, gather_value)
 
   def testIndicesWithDifferentDimensions(self):
diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py
index 6fe5a66..bbe746e 100644
--- a/tensorflow/compiler/tests/image_ops_test.py
+++ b/tensorflow/compiler/tests/image_ops_test.py
@@ -605,10 +605,6 @@
 class NonMaxSuppressionTest(xla_test.XLATestCase):
 
   def testNMS128From1024(self):
-    # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
-    if self.device in ["XLA_CPU", "XLA_GPU"]:
-      return
-
     with compat.forward_compatibility_horizon(2018, 8, 8):
       num_boxes = 1024
       boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4")
@@ -644,10 +640,6 @@
         self.assertEqual(indices_tf.size, max_output_size)
 
   def testNMS3From6Boxes(self):
-    # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
-    if self.device in ["XLA_CPU", "XLA_GPU"]:
-      return
-
     with compat.forward_compatibility_horizon(2018, 8, 8):
       # Three boxes are selected based on IOU.
       boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
@@ -693,10 +685,6 @@
     # Three boxes are selected based on IOU.
     # One is filtered out by score threshold.
 
-    # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
-    if self.device in ["XLA_CPU", "XLA_GPU"]:
-      return
-
     with compat.forward_compatibility_horizon(2018, 8, 8):
       boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
                     [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]]
diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py
index 0839fb1..de68ff0 100644
--- a/tensorflow/compiler/tests/jit_test.py
+++ b/tensorflow/compiler/tests/jit_test.py
@@ -77,11 +77,11 @@
   return any([substr in x for x in labels])
 
 
-def MetadataHasXlaLaunch(run_metadata):
-  """Returns true if there is a XlaLaunch kernel in run_metadata's timeline."""
+def MetadataHasXlaOp(run_metadata):
+  """Returns true if there are XlaRun kernels in run_metadata's timeline."""
 
   # TODO(phawkins): find a less hacky way to test whether a kernel ran.
-  return InLabels(RunMetadataLabels(run_metadata), "XlaLaunch")
+  return InLabels(RunMetadataLabels(run_metadata), "XlaRun")
 
 
 class JitLaunchTest(test.TestCase):
@@ -90,9 +90,10 @@
   # Verifies that the outputs match and that XLA was invoked. 'fn' must take
   # the same number of tensors as arguments that are in 'args', and must return
   # a tuple of output tensors.
-  # If 'require_kernel_launch' is True, then we verify that a XlaLaunch node
-  # actually ran. However, it is sometimes possible for XlaLaunch ops to be
-  # constant-folded away, so the check is optional.
+  #
+  # If 'require_kernel_launch' is True, then we verify that an XlaCompile/XlaRun
+  # node actually ran. However, it is sometimes possible for XlaCompile/XlaRun
+  # ops to be constant-folded away, so the check is optional.
   def _compare(self, fn, args, require_kernel_launch=True, noinline=None):
     with session_lib.Session(config=NoRewriteSessionConfig()) as sess:
       placeholders = []
@@ -115,7 +116,7 @@
       print("Compiled Result {}".format(compiled))
 
       if require_kernel_launch:
-        self.assert_(MetadataHasXlaLaunch(run_metadata))
+        self.assert_(MetadataHasXlaOp(run_metadata))
 
         direct = sess.run(direct_op, feeds)
         print("Direct Result {}".format(direct))
@@ -149,10 +150,10 @@
       y = math_ops.add(x, x)
       return y, y
 
-    # Exercises compling a function (say, Foo) which calls another
-    # function (say, Bar) which is not inlined. When the compiler compiles
-    # Foo, it needs to symbolic execute Bar correctly regardless whether
-    # Bar is inlined or not.
+    # Exercises compiling a function (say, Foo) which calls another function
+    # (say, Bar) which is not inlined. When the compiler compiles Foo, it needs
+    # to symbolically execute Bar correctly regardless of whether Bar is inlined
+    # or not.
 
     # TODO(b/36139787): Re-enable this test when noinline works again.
     # Tests compiled=True and noinline=True.
@@ -259,7 +260,7 @@
         # TODO(phawkins): really we would like to test that there were exactly
         # two kernel launches. However, we have no reliable way to determine
         # that.
-        self.assert_(MetadataHasXlaLaunch(run_metadata))
+        self.assert_(MetadataHasXlaOp(run_metadata))
 
         expected = np.square(np.dot(dx, dw) + db)
         self.assertAllClose(expected, output, rtol=1e-1)
@@ -289,7 +290,7 @@
                      run_metadata=run_metadata,
                      options=config_pb2.RunOptions(
                          trace_level=config_pb2.RunOptions.FULL_TRACE))
-      self.assert_(MetadataHasXlaLaunch(run_metadata))
+      self.assert_(MetadataHasXlaOp(run_metadata))
       self.assertAllClose(np.array([[1, 2, 3], [4, 5, 6]], np.float32), out)
 
   def testIgnoredArguments(self):
@@ -313,7 +314,7 @@
                      run_metadata=run_metadata,
                      options=config_pb2.RunOptions(
                          trace_level=config_pb2.RunOptions.FULL_TRACE))
-      self.assert_(MetadataHasXlaLaunch(run_metadata))
+      self.assert_(MetadataHasXlaOp(run_metadata))
       self.assertAllClose(28, out)
 
   def testLoops(self):
@@ -331,7 +332,7 @@
                            run_metadata=run_metadata,
                            options=config_pb2.RunOptions(
                                trace_level=config_pb2.RunOptions.FULL_TRACE))
-      self.assert_(MetadataHasXlaLaunch(run_metadata))
+      self.assert_(MetadataHasXlaOp(run_metadata))
       self.assertAllClose(result, np.float32(95), rtol=1e-1)
 
   def testCond(self):
@@ -356,7 +357,7 @@
                            run_metadata=run_metadata,
                            options=config_pb2.RunOptions(
                                trace_level=config_pb2.RunOptions.FULL_TRACE))
-      self.assert_(MetadataHasXlaLaunch(run_metadata))
+      self.assert_(MetadataHasXlaOp(run_metadata))
       self.assertAllClose(result, np.float32(6), rtol=1e-1)
 
   def testNestedFunction(self):
@@ -441,14 +442,16 @@
     self.assertFalse(InLabels(labels, "Log"))
     self.assertTrue(InLabels(labels, "Reciprocal"))
     self.assertTrue(InLabels(labels, "Mul"))
-    self.assertFalse(InLabels(labels, "XlaLaunch"))
+    self.assertFalse(InLabels(labels, "XlaCompile"))
+    self.assertFalse(InLabels(labels, "XlaRun"))
 
-    # Compile the backprop. One XlaLaunch.
+    # Compile the backprop. One XlaCompile/XlaRun pair.
     labels = _Run(compiled=True)
     self.assertFalse(InLabels(labels, "Log"))
     self.assertFalse(InLabels(labels, "Reciprocal"))
     self.assertFalse(InLabels(labels, "Mul"))
-    self.assertTrue(InLabels(labels, "XlaLaunch"))
+    self.assertTrue(InLabels(labels, "XlaCompile"))
+    self.assertTrue(InLabels(labels, "XlaRun"))
 
 
 class ElementWiseFusionTest(test.TestCase):
@@ -482,9 +485,12 @@
               trace_level=config_pb2.RunOptions.FULL_TRACE))
 
       labels = RunMetadataLabels(run_metadata)
-      count = sum("XlaLaunch(" in x for x in labels)
 
-      return output, count
+      xla_compile_count = sum("XlaCompile(" in x for x in labels)
+      xla_run_count = sum("XlaRun(" in x for x in labels)
+      self.assertEqual(xla_compile_count, xla_run_count)
+
+      return output, xla_run_count
 
   def testElementWiseClustering(self):
     arg0 = np.random.rand(2, 2).astype(np.float32)
diff --git a/tensorflow/compiler/tests/quantized_ops_test.py b/tensorflow/compiler/tests/quantized_ops_test.py
new file mode 100644
index 0000000..80c3385
--- /dev/null
+++ b/tensorflow/compiler/tests/quantized_ops_test.py
@@ -0,0 +1,48 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for quantized operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import googletest
+
+
+class QuantizedOpsTest(xla_test.XLATestCase):
+
+  # Verify that quantized types can be clustered by XLA.
+  def testQuantizedTypeRoundtrip(self):
+    with self.cached_session() as session:
+      for dtype in self.quantized_tf_types:
+        in_values = np.array([1, 2, 3, 4, 5, 6])
+        expected = [[1, 2], [3, 4], [5, 6]]
+        with self.test_scope():
+          p = array_ops.placeholder(dtype=dtypes.int32)
+          x = math_ops.cast(p, dtype)
+          x = array_ops.reshape(x, [3, 2])
+
+        value = session.run(x, {p: in_values})
+        self.assertAllEqual(value, expected)
+
+
+if __name__ == "__main__":
+  googletest.main()
diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py
index 6e18344..36ef6ed 100644
--- a/tensorflow/compiler/tests/random_ops_test.py
+++ b/tensorflow/compiler/tests/random_ops_test.py
@@ -35,7 +35,8 @@
   """Test cases for random-number generating operators."""
 
   def _random_types(self):
-    return set(self.numeric_types) - set(self.complex_types)
+    return set(self.numeric_types) - set(
+        self.complex_types) - {np.uint8, np.int8}
 
   def _testRngIsNotConstant(self, rng, dtype):
     # Tests that 'rng' does not always return the same value.
@@ -68,9 +69,8 @@
     def rng(dtype):
       return random_ops.random_normal(shape=[2], dtype=dtype)
 
-    # TODO(b/34339814): implement inverse erf support for non-F32 types.
-    dtype = dtypes.float32
-    self._testRngIsNotConstant(rng, dtype)
+    for dtype in self._random_types() & self.float_types:
+      self._testRngIsNotConstant(rng, dtype)
 
   def testRandomUniformIsInRange(self):
     for dtype in self._random_types():
@@ -92,13 +92,13 @@
     def rng(dtype):
       return random_ops.truncated_normal(shape=[2], dtype=dtype)
 
-    # TODO(b/34339814): implement inverse erf support for non-F32 types.
-    self._testRngIsNotConstant(rng, dtypes.float32)
+    for dtype in self._random_types() & self.float_types:
+      self._testRngIsNotConstant(rng, dtype)
 
   def testTruncatedNormalIsInRange(self):
     count = 10000000
-    # TODO(b/34339814): implement inverse erf support for non-F32 types.
-    for dtype in [dtypes.float32]:
+    # TODO(b/34339814): make this test work with 16 bit float types.
+    for dtype in self._random_types() & {dtypes.float32, dtypes.float64}:
       with self.cached_session() as sess:
         with self.test_scope():
           x = random_ops.truncated_normal(shape=[count], dtype=dtype)
@@ -144,9 +144,6 @@
         self.assertAllClose(actual_variance, expected_variance, rtol=2*1e-3)
 
   def testShuffle1d(self):
-    # TODO(b/26783907): this test requires the CPU backend to implement sort.
-    if self.device in ["XLA_CPU"]:
-      return
     with self.cached_session() as sess:
       with self.test_scope():
         x = math_ops.range(1 << 16)
diff --git a/tensorflow/compiler/tests/reverse_sequence_op_test.py b/tensorflow/compiler/tests/reverse_sequence_op_test.py
index 60c2337..abc822e 100644
--- a/tensorflow/compiler/tests/reverse_sequence_op_test.py
+++ b/tensorflow/compiler/tests/reverse_sequence_op_test.py
@@ -85,7 +85,7 @@
 
   def testSeqLength(self):
     for dtype in self.all_types:
-      for seq_dtype in self.int_types:
+      for seq_dtype in self.all_types & {np.int32, np.int64}:
         self._testBasic(dtype, seq_dtype)
 
 
diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py
index 51c04b5..dbf4beb 100644
--- a/tensorflow/compiler/tests/sort_ops_test.py
+++ b/tensorflow/compiler/tests/sort_ops_test.py
@@ -48,10 +48,6 @@
         self.assertAllClose(v, result, rtol=1e-3)
 
   def testSort(self):
-    # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
-    if self.device in ["XLA_CPU", "XLA_GPU"]:
-      return
-
     supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32])
     for dtype in supported_types.intersection(self.numeric_types):
       x = np.arange(101, dtype=dtype)
@@ -60,10 +56,6 @@
           xla.sort, [x], expected=[np.arange(101, dtype=dtype)])
 
   def testTopK(self):
-    # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
-    if self.device in ["XLA_CPU", "XLA_GPU"]:
-      return
-
     supported_types = set(
         [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
     for dtype in supported_types.intersection(self.numeric_types):
@@ -89,10 +81,6 @@
               expected=[x[indices].astype(dtype), indices])
 
   def testTopK2D(self):
-    # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
-    if self.device in ["XLA_CPU", "XLA_GPU"]:
-      return
-
     supported_types = set(
         [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
     for dtype in supported_types.intersection(self.numeric_types):
@@ -122,10 +110,6 @@
 
   def testTopKZeros(self):
     """Tests that positive and negative zeros sort correctly."""
-    # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
-    if self.device in ["XLA_CPU", "XLA_GPU"]:
-      return
-
     # Only bfloat16 is implemented.
     bfloat16 = dtypes.bfloat16.as_numpy_dtype
     if bfloat16 not in self.numeric_types:
@@ -144,10 +128,6 @@
 
   def testTopKInfinities(self):
     """Tests that positive and negative infinity sort correctly."""
-    # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
-    if self.device in ["XLA_CPU", "XLA_GPU"]:
-      return
-
     # Only bfloat16 is implemented.
     bfloat16 = dtypes.bfloat16.as_numpy_dtype
     if bfloat16 not in self.numeric_types:
diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py
index 1bea7d9..f386104 100644
--- a/tensorflow/compiler/tests/stateless_random_ops_test.py
+++ b/tensorflow/compiler/tests/stateless_random_ops_test.py
@@ -34,7 +34,7 @@
   """Test cases for stateless random-number generator operators."""
 
   def _random_types(self):
-    return [dtypes.float32]
+    return self.float_types & {dtypes.float32, dtypes.float64}
 
   def testDeterminism(self):
     # Stateless values should be equal iff the seeds are equal (roughly)
@@ -124,8 +124,7 @@
         self.assertTrue(self._anderson_darling(y) < 2.492)
 
   def testTruncatedNormalIsInRange(self):
-    # TODO(b/34339814): implement inverse erf support for non-F32 types.
-    for dtype in [dtypes.float32]:
+    for dtype in self._random_types():
       with self.cached_session() as sess, self.test_scope():
         seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
         n = 10000000
@@ -159,7 +158,7 @@
         # Department of Scientific Computing website. Florida State University.
         expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma
         actual_mean = np.mean(y)
-        self.assertAllClose(actual_mean, expected_mean, atol=2e-4)
+        self.assertAllClose(actual_mean, expected_mean, atol=5e-4)
 
         expected_median = mu + probit(
             (normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma
diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py
index 55a9921..98a0770 100644
--- a/tensorflow/compiler/tests/ternary_ops_test.py
+++ b/tensorflow/compiler/tests/ternary_ops_test.py
@@ -122,8 +122,7 @@
           expected=np.array([[2], [5]], dtype=dtype))
 
   def testClipByValue(self):
-    # TODO(b/78258593): enable integer types here too.
-    for dtype in self.float_types:
+    for dtype in self.numeric_types - self.complex_types:
       test_cases = [
           (np.array([2, 4, 5], dtype=dtype), dtype(7)),  #
           (dtype(1), np.array([2, 4, 5], dtype=dtype)),  #
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index 5b0e57f..77f6eee 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -84,7 +84,7 @@
       self.assertAllClose(result[i], expected[i], rtol, atol)
 
   def testAllTypeOps(self):
-    for dtype in self.numeric_types:
+    for dtype in self.numeric_types - {np.int8, np.uint8}:
       self._assertOpOutputMatchesExpected(
           array_ops.diag, np.array([1, 2, 3, 4], dtype=dtype),
           np.array(
@@ -158,9 +158,6 @@
 
   def testFloatOps(self):
     for dtype in self.float_types:
-      # TODO(b/77694432): Half test failed on CPU, last ran on 04-06-2018.
-      if dtype == np.float16 and self.device == "XLA_CPU":
-        continue
       x = np.arange(-0.90, 0.90, 0.25)
       self._assertOpOutputMatchesExpected(
           math_ops.acos, x.astype(dtype), expected=np.arccos(x).astype(dtype))
@@ -633,7 +630,7 @@
           expected=np.array([-1, 0, -2, -17, -43], dtype=dtype))
 
   def testNumericOps(self):
-    for dtype in self.numeric_types:
+    for dtype in self.numeric_types - {np.int8, np.uint8}:
       self._assertOpOutputMatchesExpected(
           math_ops.abs,
           np.array([[2, -1]], dtype=dtype),
diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py
index 1e600c4..4cf88fc 100644
--- a/tensorflow/compiler/tests/xla_ops_test.py
+++ b/tensorflow/compiler/tests/xla_ops_test.py
@@ -181,7 +181,7 @@
               dtype=dtype))
 
   def testNeg(self):
-    for dtype in self.numeric_types:
+    for dtype in self.numeric_types - {np.uint8, np.int8}:
       self._assertOpOutputMatchesExpected(
           xla.neg,
           args=(np.array([1, 2, 3], dtype=dtype),),
diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py
index 88827cb..98a4198 100644
--- a/tensorflow/compiler/tests/xla_test.py
+++ b/tensorflow/compiler/tests/xla_test.py
@@ -97,10 +97,23 @@
     ])
     self._numeric_tf_types = set(
         self.int_tf_types | self._float_tf_types | self.complex_tf_types)
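+    # TF dtypes that are quantized (e.g. qint8); they have no numpy
+    # equivalent.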
+    self.quantized_tf_types = set(
+        dtype for dtype in self._all_tf_types if dtype.is_quantized)
 
-    self._all_types = set(
-        [dtype.as_numpy_dtype for dtype in self._all_tf_types])
+    # Quantized types don't have a numpy equivalent; include them in
+    # all_tf_types but not in all_types.
+    # TODO(b/115960798): Parametrize tests on TF types instead of numpy types
+    # and remove all_types.
+    self._all_types = set(dtype.as_numpy_dtype
+                          for dtype in self._all_tf_types
+                          if not dtype.is_quantized)
     self._int_types = set([dtype.as_numpy_dtype for dtype in self.int_tf_types])
+    self.signed_int_types = set(dtype.as_numpy_dtype
+                                for dtype in self.int_tf_types
+                                if not dtype.is_unsigned)
+    self.unsigned_int_types = set(dtype.as_numpy_dtype
+                                  for dtype in self.int_tf_types
+                                  if dtype.is_unsigned)
     self._float_types = set(
         [dtype.as_numpy_dtype for dtype in self._float_tf_types])
     self.complex_types = set([
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index d549e7b..ba1e3b2 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -611,6 +611,7 @@
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
     ],
diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc
index 922ae7c..027ca6d 100644
--- a/tensorflow/compiler/tf2xla/const_analysis.cc
+++ b/tensorflow/compiler/tf2xla/const_analysis.cc
@@ -29,14 +29,6 @@
                               std::vector<bool>* compile_time_const_arg_indices,
                               std::vector<bool>* compile_time_const_nodes,
                               std::function<bool(const Edge&)> edge_filter) {
-  // Operators that don't look at the data of their inputs, just the shapes.
-  const std::unordered_set<string> metadata_ops = {
-      "Rank",
-      "Shape",
-      "ShapeN",
-      "Size",
-  };
-
   std::vector<bool> compile_time_const_nodes_impl;
   if (compile_time_const_nodes) {
     CHECK_EQ(compile_time_const_nodes->size(), g.num_node_ids());
@@ -50,7 +42,9 @@
     if (!status.ok()) return;
 
     // If this is a metadata-only op, don't propagate the const requirement.
-    if (metadata_ops.find(node->type_string()) != metadata_ops.end()) return;
+    if (XlaOpRegistry::IsMetadataOp(node->type_string())) {
+      return;
+    }
 
     // If this node must be const, and it isn't a metadata op, then all of its
     // parents must be const.
diff --git a/tensorflow/compiler/tf2xla/graph_compiler.h b/tensorflow/compiler/tf2xla/graph_compiler.h
index ab7cac7..e9f0220 100644
--- a/tensorflow/compiler/tf2xla/graph_compiler.h
+++ b/tensorflow/compiler/tf2xla/graph_compiler.h
@@ -55,17 +55,17 @@
 // op registration infrastructure instead of FunctionLibraryRuntime.
 class GraphCompiler {
  public:
-  GraphCompiler(XlaContext* xla_context, XlaCompilationDevice* device,
-                Graph* graph, FunctionLibraryRuntime* flib,
+  GraphCompiler(XlaCompilationDevice* device, Graph* graph,
+                FunctionLibraryRuntime* flib,
                 ScopedStepContainer* step_container)
-      : xla_context_(xla_context),
-        device_(device),
+      : device_(device),
         graph_(graph),
         flib_(flib),
         step_container_(step_container) {}
 
-  // Compiles the graph. The results are written in `xla_context` that is passed
-  // into the compiler.
+  // Compiles the graph. The results are written to the xla_context stored in
+  // the resource_manager of the 'XlaCompilationDevice' that is passed into
+  // the constructor.
   Status Compile();
 
  private:
@@ -82,7 +82,6 @@
   // using `compiler_`.
   Status CompileFunctionalNode(Node* n, OpKernelContext* op_context);
 
-  XlaContext* xla_context_;
   XlaCompilationDevice* device_;
   Graph* graph_;
   FunctionLibraryRuntime* flib_;
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 46794f7..3e82325 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -113,6 +113,7 @@
         "shape_util.h",
     ],
     deps = [
+        ":conv_op_helpers",
         ":if_op",
         ":while_op",
         "//tensorflow/compiler/tf2xla:common",
@@ -172,6 +173,27 @@
     ],
 )
 
+cc_library(
+    name = "conv_op_helpers",
+    srcs = ["conv_op_helpers.cc"],
+    hdrs = ["conv_op_helpers.h"],
+    deps = [
+        "//tensorflow/compiler/tf2xla:common",
+        "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/xla:literal_util",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla/client:xla_builder",
+        "//tensorflow/compiler/xla/client/lib:arithmetic",
+        "//tensorflow/compiler/xla/client/lib:constants",
+        "//tensorflow/compiler/xla/client/lib:numeric",
+        "//tensorflow/core:framework",
+        "//tensorflow/core/kernels:bounds_check",
+        "//tensorflow/core/kernels:conv_ops",
+        "//tensorflow/core/kernels:ops_util",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
 tf_kernel_library(
     name = "while_op",
     srcs = ["while_op.cc"],
diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
index df17da4..0d9a768 100644
--- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
@@ -66,6 +66,9 @@
 static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
                                xla::XlaOp y, const BCast& broadcast_helper) {
   std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
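+  // For unsigned types the quotient is already non-negative, so truncating
+  // division and floor division coincide and no sign correction is needed.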
+  if (DataTypeIsUnsigned(dtype)) {
+    return xla::Div(x, y);
+  }
   auto zero = XlaHelpers::Zero(b, dtype);
   auto one = XlaHelpers::One(b, dtype);
   auto different_sign = xla::Ne(xla::Lt(x, zero), xla::Lt(y, zero));
diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
index f410605..0ae23aa 100644
--- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
@@ -37,6 +37,16 @@
 namespace tensorflow {
 namespace {
 
+// Used to determine the number of Tensors allowed in a Concat op to prevent
+// going over the max gpu parameter memory size. This is an issue because concat
+// is variadic and can have an unlimited number of arguments when called.
+// Concat ops with more Tensors than this will be split into multiple concat
+// ops.
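+// For example, a concat with roughly 1,200 inputs is emitted as a concat of a
+// few partial concats plus the remaining trailing inputs, rather than a single
+// 1,200-ary concat.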
+//
+// TODO(b/112613927): Remove the logic here and put it properly in an HLO pass
+// along with boxing large numbers of parameters.
+constexpr int64 kMaxConcatArgsPerOp = 500;
+
 // --------------------------------------------------------------------------
 class ConcatBaseOp : public XlaOpKernel {
  public:
@@ -74,6 +84,7 @@
     // Make a vector holding the XlaOp for each of the inputs that has non-zero
     // elements.
     std::vector<xla::XlaOp> input_data;
+    std::vector<xla::XlaOp> partial_concats;
     int output_concat_dim = 0;
     const bool input_is_scalar = IsLegacyScalar(input_shape);
     for (int i = 0; i < N; ++i) {
@@ -94,10 +105,30 @@
         input_data.push_back(handle);
       }
       output_concat_dim += in_shape.dims() > 0 ? in_shape.dim_size(axis) : 1;
+
+      // Concat is associative, so it can be split into many operations when too
+      // many arguments are in a single op. This is a temporary workaround for
+      // b/112613927 where too many parameters in an XlaLaunchOp later result in
+      // too many parameters to a single GPU kernel.
+      if (i && i % kMaxConcatArgsPerOp == 0) {
+        partial_concats.push_back(
+            xla::ConcatInDim(ctx->builder(), input_data, axis));
+        input_data.clear();
+      }
     }
+    // Add any inputs that have not been put into another concat yet.
+    partial_concats.insert(partial_concats.end(), input_data.begin(),
+                           input_data.end());
 
     VLOG(1) << "Concat dim " << concat_dim << " equivalent to " << axis;
-    ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), input_data, axis));
+    // Don't add an additional "identity" concatenate, for better readability
+    // of the IR.
+    if (partial_concats.size() == 1) {
+      ctx->SetOutput(0, partial_concats.front());
+    } else {
+      ctx->SetOutput(0,
+                     xla::ConcatInDim(ctx->builder(), partial_concats, axis));
+    }
   }
 
  private:
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
new file mode 100644
index 0000000..c9a1be4
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
@@ -0,0 +1,509 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA-specific Ops for 2D convolution.
+
+#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h"
+#include "absl/types/span.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/conv_grad_ops.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+namespace tensorflow {
+namespace {
+
+// Returns the expanded size of a filter used for depthwise convolution.
+// If `shape` is [H, W, ..., M, N] returns [H, W, ..., M, M*N].
+xla::Shape ExpandedFilterShapeForDepthwiseConvolution(const xla::Shape& shape) {
+  int num_dims = shape.dimensions_size();
+  CHECK_GE(num_dims, 2);  // Crash OK
+  xla::Shape expanded_shape = shape;
+  expanded_shape.set_dimensions(
+      num_dims - 1,
+      shape.dimensions(num_dims - 2) * shape.dimensions(num_dims - 1));
+  return expanded_shape;
+}
+
+// Create a mask for depthwise convolution that will make a normal convolution
+// produce the same results as a depthwise convolution. For a [2, 2, 3, 2]
+// depthwise filter this returns a [2, 2, 3, 6] tensor
+//   1 1 0 0 0 0   1 1 0 0 0 0
+//   0 0 1 1 0 0   0 0 1 1 0 0
+//   0 0 0 0 1 1   0 0 0 0 1 1
+//
+//   1 1 0 0 0 0   1 1 0 0 0 0
+//   0 0 1 1 0 0   0 0 1 1 0 0
+//   0 0 0 0 1 1   0 0 0 0 1 1
+//
+// The first step is to create an iota tensor, A, of shape [3]
+//   0 1 2
+//
+// and another iota tensor, B, of shape [3 * 2]
+//   0 1 2 3 4 5
+//
+// and divide B by 2 to get
+//   0 0 1 1 2 2
+//
+// then we broadcast B to [2, 2, 3, 3 * 2]
+//   0 0 1 1 2 2   0 0 1 1 2 2
+//   0 0 1 1 2 2   0 0 1 1 2 2
+//   0 0 1 1 2 2   0 0 1 1 2 2
+//
+//   0 0 1 1 2 2   0 0 1 1 2 2
+//   0 0 1 1 2 2   0 0 1 1 2 2
+//   0 0 1 1 2 2   0 0 1 1 2 2
+//
+// Finally, compare A and the broadcasted B in dimension 2 and return the mask
+// shown at the beginning of the comment.
+xla::XlaOp CreateExpandedFilterMask(const xla::Shape& filter_shape,
+                                    xla::XlaBuilder* builder) {
+  xla::Shape expanded_filter_shape =
+      ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
+  int64 depthwise_multiplier =
+      filter_shape.dimensions(filter_shape.dimensions_size() - 1);
+  int64 input_feature =
+      filter_shape.dimensions(filter_shape.dimensions_size() - 2);
+
+  // Create an M-sized linspace and an M*N-sized linspace that will be
+  // broadcasted into perpendicular dimensions and compared.
+  xla::XlaOp input_feature_iota = xla::Iota(builder, xla::S32, input_feature);
+  xla::XlaOp expanded_feature_iota =
+      xla::Iota(builder, xla::S32, input_feature * depthwise_multiplier);
+
+  // Divide the M*N sized linspace by the depthwise_multiplier to create
+  // [0 0 1 1 2 2] in the example in the function comment.
+  expanded_feature_iota =
+      xla::Div(expanded_feature_iota,
+               XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32,
+                                          depthwise_multiplier));
+
+  // Broadcast the M*N linspace to [H, W, ..., M, M*N].
+  std::vector<int64> expanded_feature_broadcast_dims(
+      expanded_filter_shape.dimensions().begin(),
+      expanded_filter_shape.dimensions().end());
+  expanded_feature_broadcast_dims.pop_back();
+  auto broadcasted_expanded_feature_iota =
+      xla::Broadcast(expanded_feature_iota, expanded_feature_broadcast_dims);
+
+  // Compare the broadcasted linspace to the input feature linspace in the
+  // input feature dimension to create a diagonal predicate.
+  return xla::Eq(broadcasted_expanded_feature_iota, input_feature_iota,
+                 {expanded_filter_shape.dimensions_size() - 2});
+}
+
+// Reshapes a filter of shape [H, W, ..., M, N] to [H, W, ..., 1, M*N]. Used to
+// build a depthwise convolution.
+xla::XlaOp ReshapeFilterForDepthwiseConvolution(const xla::Shape& filter_shape,
+                                                const xla::XlaOp& filter) {
+  int64 input_feature_dim = filter_shape.dimensions_size() - 2;
+  int64 output_feature_dim = filter_shape.dimensions_size() - 1;
+  int64 depthwise_multiplier = filter_shape.dimensions(output_feature_dim);
+  int64 input_feature = filter_shape.dimensions(input_feature_dim);
+
+  // Create a [H, W, ..., 1, M*N] reshape of the filter.
+  xla::Shape implicit_broadcast_filter_shape = filter_shape;
+  implicit_broadcast_filter_shape.set_dimensions(input_feature_dim, 1);
+  implicit_broadcast_filter_shape.set_dimensions(
+      output_feature_dim, depthwise_multiplier * input_feature);
+  return xla::Reshape(
+      filter, xla::AsInt64Slice(implicit_broadcast_filter_shape.dimensions()));
+}
+
+// Reduces the results of the convolution with an expanded filter to the
+// non-expanded filter.
+xla::XlaOp ContractFilterForDepthwiseBackprop(const xla::Shape& filter_shape,
+                                              const xla::XlaOp& filter_backprop,
+                                              xla::XlaBuilder* builder) {
+  auto masked_expanded_filter =
+      xla::Select(CreateExpandedFilterMask(filter_shape, builder),
+                  filter_backprop, xla::ZerosLike(filter_backprop));
+
+  auto elem_type = filter_shape.element_type();
+  return xla::Reshape(
+      // This reduce does not need inputs to be converted with
+      // XlaHelpers::SumAccumulationType() since the select above guarantees
+      // that only one element is non zero, so there cannot be accumulated
+      // precision error.
+      xla::Reduce(masked_expanded_filter, xla::Zero(builder, elem_type),
+                  CreateScalarAddComputation(elem_type, builder),
+                  {filter_shape.dimensions_size() - 2}),
+      xla::AsInt64Slice(filter_shape.dimensions()));
+}
+
+// Performs some basic checks on ConvOpAttrs that apply to all kinds of XLA
+// convolutions (as currently implemented).
+Status CheckConvAttrs(const ConvOpAttrs& attrs) {
+  const int num_dims = attrs.num_spatial_dims + 2;
+  if (attrs.strides.size() != num_dims) {
+    return errors::InvalidArgument("Sliding window strides field must specify ",
+                                   num_dims, " dimensions");
+  }
+  int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
+  int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
+  if (attrs.strides[batch_dim] != 1 || attrs.strides[feature_dim] != 1) {
+    return errors::Unimplemented(
+        "Current implementation does not yet support strides in the batch and "
+        "depth dimensions.");
+  }
+  if (attrs.dilations.size() != num_dims) {
+    return errors::InvalidArgument("Dilations field must specify ", num_dims,
+                                   " dimensions");
+  }
+  if (attrs.dilations[batch_dim] != 1 || attrs.dilations[feature_dim] != 1) {
+    return errors::Unimplemented(
+        "Current implementation does not support dilations in the batch and "
+        "depth dimensions.");
+  }
+  for (int i = 0; i < attrs.num_spatial_dims; ++i) {
+    int input_dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
+    if (attrs.dilations[input_dim] < 1) {
+      return errors::Unimplemented("Dilation values must be positive; ", i,
+                                   "th spatial dimension had dilation ",
+                                   attrs.dilations[input_dim]);
+    }
+  }
+  return Status::OK();
+}
+
+// Wrapper around ConvBackpropComputeDimensions that converts from XLA shapes
+// to TensorShapes.
+Status ConvBackpropComputeDimensionsV2XlaShapes(
+    StringPiece label, int num_spatial_dims, const xla::Shape& input_shape,
+    const xla::Shape& filter_shape, const xla::Shape& out_backprop_shape,
+    absl::Span<const int32> dilations, const std::vector<int32>& strides,
+    Padding padding, TensorFormat data_format, ConvBackpropDimensions* dims) {
+  TensorShape input_tensor_shape, filter_tensor_shape,
+      out_backprop_tensor_shape;
+  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_tensor_shape));
+  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(filter_shape, &filter_tensor_shape));
+  TF_RETURN_IF_ERROR(
+      XLAShapeToTensorShape(out_backprop_shape, &out_backprop_tensor_shape));
+  return ConvBackpropComputeDimensionsV2(
+      label, num_spatial_dims, input_tensor_shape, filter_tensor_shape,
+      out_backprop_tensor_shape, dilations, strides, padding, data_format,
+      dims);
+}
+
+}  // anonymous namespace
+
+xla::StatusOr<ConvOpAttrs> ConvOpAttrs::Create(int num_spatial_dims,
+                                               bool depthwise,
+                                               OpKernelConstruction* ctx) {
+  ConvOpAttrs attrs;
+  attrs.num_spatial_dims = num_spatial_dims;
+  attrs.depthwise = depthwise;
+  TF_RETURN_IF_ERROR(ctx->GetAttr("dilations", &attrs.dilations));
+  TF_RETURN_IF_ERROR(ctx->GetAttr("strides", &attrs.strides));
+  TF_RETURN_IF_ERROR(ctx->GetAttr("padding", &attrs.padding));
+
+  string data_format;
+  TF_RETURN_IF_ERROR(ctx->GetAttr("data_format", &data_format));
+  if (!FormatFromString(data_format, &attrs.data_format)) {
+    return errors::InvalidArgument("Invalid data format: ", data_format);
+  }
+
+  return attrs;
+}
+
+xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece /*type_string*/,
+                                               xla::XlaOp conv_input,
+                                               xla::XlaOp filter,
+                                               const ConvOpAttrs& attrs) {
+  TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
+
+  auto* builder = conv_input.builder();
+  TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(conv_input));
+  // Filter has the form [filter_rows, filter_cols, ..., in_depth, out_depth]
+  TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter));
+
+  // For 2D convolution, there should be 4 dimensions.
+  int num_dims = attrs.num_spatial_dims + 2;
+  if (input_shape.dimensions_size() != num_dims) {
+    return errors::InvalidArgument("input must be ", num_dims, "-dimensional",
+                                   input_shape.DebugString());
+  }
+  if (filter_shape.dimensions_size() != num_dims) {
+    return errors::InvalidArgument(
+        "filter must be ", num_dims,
+        "-dimensional: ", filter_shape.DebugString());
+  }
+
+  // The last two dimensions of the filter hold the input and output depths.
+  int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
+  int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
+
+  int64 in_depth = filter_shape.dimensions(attrs.num_spatial_dims);
+  // The 'C' dimension for input is in_depth. It must be the same as
+  // the filter's in_depth.
+  if (in_depth != input_shape.dimensions(feature_dim)) {
+    return errors::InvalidArgument(
+        "input and filter must have the same depth: ", in_depth, " vs ",
+        input_shape.dimensions(feature_dim));
+  }
+
+  if (attrs.depthwise) {
+    filter = ReshapeFilterForDepthwiseConvolution(filter_shape, filter);
+  }
+
+  xla::ConvolutionDimensionNumbers dims;
+  std::vector<int64> window_strides(attrs.num_spatial_dims);
+  std::vector<int64> lhs_dilation(attrs.num_spatial_dims, 1);
+  std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
+  std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
+
+  dims.set_input_batch_dimension(batch_dim);
+  dims.set_output_batch_dimension(batch_dim);
+  dims.set_input_feature_dimension(feature_dim);
+  dims.set_output_feature_dimension(feature_dim);
+  dims.set_kernel_input_feature_dimension(attrs.num_spatial_dims);
+  dims.set_kernel_output_feature_dimension(attrs.num_spatial_dims + 1);
+
+  for (int i = 0; i < attrs.num_spatial_dims; ++i) {
+    const int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
+    dims.add_input_spatial_dimensions(dim);
+    dims.add_kernel_spatial_dimensions(i);
+    dims.add_output_spatial_dimensions(dim);
+    window_strides[i] = attrs.strides.at(dim);
+    rhs_dilation[i] = attrs.dilations.at(dim);
+
+    int64 unused_output_size;
+    TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2(
+        input_shape.dimensions(dim), filter_shape.dimensions(i),
+        rhs_dilation[i], window_strides[i], attrs.padding, &unused_output_size,
+        &padding[i].first, &padding[i].second));
+  }
+
+  return xla::ConvGeneralDilated(
+      conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation,
+      dims, /*feature_group_count=*/attrs.depthwise ? in_depth : 1);
+}
+
+xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
+    StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter,
+    xla::XlaOp out_backprop, const ConvOpAttrs& attrs) {
+  TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
+
+  int num_dims = attrs.num_spatial_dims + 2;
+  int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
+  int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
+
+  auto* builder = filter.builder();
+  TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter));
+  TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape,
+                      builder->GetShape(out_backprop));
+
+  xla::Shape expanded_filter_shape =
+      attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
+                      : filter_shape;
+  // Reuse dimension computation logic from conv_grad_ops.cc.
+  ConvBackpropDimensions dims;
+  TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
+      type_string, attrs.num_spatial_dims, input_shape, expanded_filter_shape,
+      out_backprop_shape, attrs.dilations, attrs.strides, attrs.padding,
+      attrs.data_format, &dims));
+
+  // The input gradients are computed by a convolution of the output
+  // gradients and the filter, with some appropriate padding. See the
+  // comment at the top of conv_grad_ops.h for details.
+
+  xla::ConvolutionDimensionNumbers dnums;
+  dnums.set_input_batch_dimension(batch_dim);
+  dnums.set_output_batch_dimension(batch_dim);
+  dnums.set_input_feature_dimension(feature_dim);
+  dnums.set_output_feature_dimension(feature_dim);
+
+  // TF filter shape is [ H, W, ..., inC, outC ]
+  // Transpose the input and output features for computing the gradient.
+  dnums.set_kernel_input_feature_dimension(attrs.num_spatial_dims + 1);
+  dnums.set_kernel_output_feature_dimension(attrs.num_spatial_dims);
+
+  std::vector<int64> kernel_spatial_dims(attrs.num_spatial_dims);
+  std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
+  std::vector<int64> lhs_dilation(attrs.num_spatial_dims);
+  std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
+  std::vector<int64> ones(attrs.num_spatial_dims, 1);
+  for (int i = 0; i < attrs.num_spatial_dims; ++i) {
+    int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
+    dnums.add_input_spatial_dimensions(dim);
+    dnums.add_kernel_spatial_dimensions(i);
+    dnums.add_output_spatial_dimensions(dim);
+
+    kernel_spatial_dims[i] = i;
+    padding[i] = {dims.spatial_dims[i].pad_before,
+                  dims.spatial_dims[i].pad_after};
+    lhs_dilation[i] = dims.spatial_dims[i].stride;
+    rhs_dilation[i] = attrs.dilations[dim];
+  }
+
+  // Mirror the filter in the spatial dimensions.
+  xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims);
+
+  // activation gradients
+  //   = gradients (with padding and dilation) <conv> mirrored_weights
+  return xla::ConvGeneralDilated(
+      out_backprop, mirrored_weights, /*window_strides=*/ones, padding,
+      lhs_dilation, rhs_dilation, dnums,
+      /*feature_group_count=*/
+      attrs.depthwise ? out_backprop_shape.dimensions(feature_dim) /
+                            filter_shape.dimensions(attrs.num_spatial_dims + 1)
+                      : 1);
+}
+
+xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
+    StringPiece type_string, xla::XlaOp activations,
+    const xla::Shape& filter_shape, xla::XlaOp gradients,
+    const ConvOpAttrs& attrs) {
+  TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
+
+  auto* builder = activations.builder();
+  TF_ASSIGN_OR_RETURN(xla::Shape activations_shape,
+                      builder->GetShape(activations));
+  TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape,
+                      builder->GetShape(gradients));
+  const xla::Shape expanded_filter_shape =
+      attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
+                      : filter_shape;
+
+  // Reuse dimension computation logic from conv_grad_ops.cc.
+  ConvBackpropDimensions dims;
+  TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
+      type_string, attrs.num_spatial_dims, activations_shape,
+      expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides,
+      attrs.padding, attrs.data_format, &dims));
+
+  // The filter gradients are computed by a convolution of the input
+  // activations and the output gradients, with some appropriate padding.
+  // See the comment at the top of conv_grad_ops.h for details.
+
+  xla::ConvolutionDimensionNumbers dnums;
+
+  // The activations (inputs) form the LHS of the convolution.
+  // Activations have shape: [batch, in_rows, in_cols, ..., in_depth]
+  // For the gradient computation, we flip the roles of the batch and
+  // feature dimensions.
+  // Each spatial entry has size in_depth * batch
+
+  // The last two dimensions of the filter hold the input and output depths.
+  int num_dims = attrs.num_spatial_dims + 2;
+  int n_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
+  int c_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
+
+  // Swap n_dim and c_dim in the activations.
+  dnums.set_input_batch_dimension(c_dim);
+  dnums.set_input_feature_dimension(n_dim);
+
+  // The gradients become the RHS of the convolution.
+  // The gradients have shape [batch, out_rows, out_cols, ..., out_depth]
+  // where the batch becomes the input feature for the convolution.
+  dnums.set_kernel_input_feature_dimension(n_dim);
+  dnums.set_kernel_output_feature_dimension(c_dim);
+
+  std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
+  std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
+  std::vector<int64> window_strides(attrs.num_spatial_dims);
+  std::vector<int64> ones(attrs.num_spatial_dims, 1);
+
+  // Tensorflow filter shape is [ H, W, ..., inC, outC ].
+  for (int i = 0; i < attrs.num_spatial_dims; ++i) {
+    dnums.add_output_spatial_dimensions(i);
+  }
+  dnums.set_output_batch_dimension(attrs.num_spatial_dims);
+  dnums.set_output_feature_dimension(attrs.num_spatial_dims + 1);
+
+  for (int i = 0; i < attrs.num_spatial_dims; ++i) {
+    int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
+    dnums.add_input_spatial_dimensions(dim);
+    dnums.add_kernel_spatial_dimensions(dim);
+
+    // We will also need to pad the input with zeros such that after the
+    // convolution, we get the right size for the filter.
+    // The padded_in_rows should be such that when we convolve this with the
+    // expanded_out_rows as a filter, we should get filter_rows back.
+    //
+    const int64 padded_in_size =
+        dims.spatial_dims[i].expanded_output_size +
+        (dims.spatial_dims[i].filter_size - 1) * attrs.dilations[dim];
+
+    // However it can be smaller than input_rows: in this
+    // case it means some of the inputs are not used.
+    //
+    // An example is to have input_cols = 3, filter_cols = 2 and stride = 2:
+    //
+    // INPUT =  [ A  B  C ]
+    //
+    // FILTER = [ x y ]
+    //
+    // and the output will only have one column: a = A * x + B * y
+    //
+    // and input "C" is not used at all.
+    //
+    // We apply negative padding in this case.
+    const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size;
+
+    // + For the VALID padding, we don't pad anything on the top/left side
+    //   and pad the bottom/right side with the remaining space.
+    // + For the SAME padding, we pad top/left side the same as bottom/right
+    //   side.
+    //
+    // In addition, if the padded input size is smaller than the input size,
+    // we need to ignore some trailing elements of the input. We do this by
+    // applying negative padding on the right/bottom.
+    const int64 pad_before =
+        attrs.padding == Padding::SAME ? std::max<int64>(pad_total / 2, 0) : 0;
+
+    padding[i] = {pad_before, pad_total - pad_before};
+    rhs_dilation[i] = dims.spatial_dims[i].stride;
+    window_strides[i] = attrs.dilations[dim];
+  }
+
+  // Besides padding the input, we will also expand output_rows to
+  //    expanded_out_rows = (output_rows - 1) * stride + 1
+  // with zeros in between:
+  //
+  //      a . . . b . . . c . . . d . . . e
+  //
+  // This is done by specifying the window dilation factors in the
+  // convolution HLO below.
+  auto filter_backprop =
+      xla::ConvGeneralDilated(activations, gradients, window_strides, padding,
+                              /*lhs_dilation=*/ones, rhs_dilation, dnums);
+
+  if (attrs.depthwise) {
+    filter_backprop = ContractFilterForDepthwiseBackprop(
+        filter_shape, filter_backprop, activations.builder());
+  }
+
+  return filter_backprop;
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h
new file mode 100644
index 0000000..6e1b70a
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h
@@ -0,0 +1,69 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_
+#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+// This header exposes utilities for translating TensorFlow convolution ops into
+// XLA ops.
+//
+// conv_ops.cc contains lowerings for many of these TF convolution ops (e.g.
+// Conv2D, Conv3DBackpropFilterV2), but you might want to use the utilities in
+// this header to implement a new and exciting convolution op, for example a
+// fused TensorFlow op that contains a convolution and other things.
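+//
+// As a rough sketch, such an op's Compile() method might forward its inputs to
+// these helpers roughly as follows (mirroring what ConvOp in conv_ops.cc does,
+// where `attrs` was built earlier via ConvOpAttrs::Create):
+//
+//   xla::StatusOr<xla::XlaOp> conv = MakeXlaForwardConvOp(
+//       ctx->op_kernel().type_string(), ctx->Input(0), ctx->Input(1), attrs);
+//   OP_REQUIRES_OK(ctx, conv.status());
+//   ctx->SetOutput(0, conv.ValueOrDie());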
+
+namespace tensorflow {
+
+// ConvOpAttrs contains all of the metadata necessary to specify a TF or XLA
+// convolution.
+struct ConvOpAttrs {
+  // Constructs a ConvOpAttrs, reading most of the attributes from `ctx`.
+  static xla::StatusOr<ConvOpAttrs> Create(int num_spatial_dims, bool depthwise,
+                                           OpKernelConstruction* ctx);
+
+  bool depthwise;
+  int num_spatial_dims;
+  std::vector<int32> dilations;
+  std::vector<int32> strides;
+  Padding padding;
+  TensorFormat data_format;
+};
+
+// Creates a new XLA forward or backward convolution with the given inputs and
+// attributes.
+xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece type_string,
+                                               xla::XlaOp conv_input,
+                                               xla::XlaOp filter,
+                                               const ConvOpAttrs& attrs);
+xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
+    StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter,
+    xla::XlaOp out_backprop, const ConvOpAttrs& attrs);
+xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
+    StringPiece type_string, xla::XlaOp activations,
+    const xla::Shape& filter_shape, xla::XlaOp gradients,
+    const ConvOpAttrs& attrs);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
index 674720e..cd7c820 100644
--- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
@@ -15,12 +15,17 @@
 
 // XLA-specific Ops for 2D convolution.
 
+#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/client/lib/numeric.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/numeric_op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -33,250 +38,28 @@
 #include "tensorflow/core/util/tensor_format.h"
 
 namespace tensorflow {
-
 namespace {
 
-// Returns the expanded size of a filter used for depthwise convolution.
-// If `shape` is [H, W, ..., M, N] returns [H, W, ..., M, M*N].
-TensorShape ExpandedFilterShapeForDepthwiseConvolution(
-    const TensorShape& shape) {
-  int num_dims = shape.dims();
-  CHECK_GE(num_dims, 2);
-  TensorShape expanded_shape = shape;
-  expanded_shape.set_dim(num_dims - 1, shape.dim_size(num_dims - 2) *
-                                           shape.dim_size(num_dims - 1));
-  return expanded_shape;
-}
-
-// Broadcast zeros to ExpandedFilterShapeForDepthwiseConvolution.
-xla::XlaOp CreateExpandedZero(const TensorShape& filter_shape, DataType dtype,
-                              xla::XlaBuilder* builder) {
-  TensorShape expanded_filter_shape =
-      ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
-  return xla::Broadcast(XlaHelpers::Zero(builder, dtype),
-                        expanded_filter_shape.dim_sizes());
-}
-
-// Create a mask for depthwise convolution that will make a normal convolution
-// produce the same results as a depthwise convolution. For a [2, 2, 3, 2]
-// depthwise filter this returns a [2, 2, 3, 6] tensor
-//   1 1 0 0 0 0   1 1 0 0 0 0
-//   0 0 1 1 0 0   0 0 1 1 0 0
-//   0 0 0 0 1 1   0 0 0 0 1 1
-//
-//   1 1 0 0 0 0   1 1 0 0 0 0
-//   0 0 1 1 0 0   0 0 1 1 0 0
-//   0 0 0 0 1 1   0 0 0 0 1 1
-//
-// The first step is to create a one tensor, A, that is [3]
-//   0 1 2
-//
-// and another tensor, B,  that is [3 * 2]
-//   0 1 2 3 4 5
-//
-// and divide B it by 2 to get
-//   0 0 1 1 2 2
-//
-// then we broadcast the B to [2, 2, 3, 3 * 2]
-//   0 0 1 1 2 2   0 0 1 1 2 2
-//   0 0 1 1 2 2   0 0 1 1 2 2
-//   0 0 1 1 2 2   0 0 1 1 2 2
-//
-//   0 0 1 1 2 2   0 0 1 1 2 2
-//   0 0 1 1 2 2   0 0 1 1 2 2
-//   0 0 1 1 2 2   0 0 1 1 2 2
-//
-// Finally compare A and broadcasted B in dimension 2 amd return the result at
-// the beginning of the comment.
-xla::XlaOp CreateExpandedFilterMask(const TensorShape& filter_shape,
-                                    xla::XlaBuilder* builder) {
-  TensorShape expanded_filter_shape =
-      ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
-  int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1);
-  int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2);
-
-  // Create a M sized linspace and an M*N sized linspace that will be
-  // broadcasted into perpendicular dimensions and compared.
-  xla::XlaOp input_feature_iota = xla::Iota(builder, xla::S32, input_feature);
-  xla::XlaOp expanded_feature_iota =
-      xla::Iota(builder, xla::S32, input_feature * depthwise_multiplier);
-
-  // Divide the M*N sized linspace by the depthwise_multiplier to create
-  // [0 0 1 1 2 2] in the example in the function comment.
-  expanded_feature_iota =
-      xla::Div(expanded_feature_iota,
-               XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32,
-                                          depthwise_multiplier));
-
-  // Broadcast the N*M linspace to [H, W, ..., M, M*N].
-  auto expanded_feature_broadcast_dims = expanded_filter_shape.dim_sizes();
-  expanded_feature_broadcast_dims.pop_back();
-  auto broadcasted_expanded_feature_iota =
-      xla::Broadcast(expanded_feature_iota, expanded_feature_broadcast_dims);
-
-  // Compare the broadcasted linspace to the input feature linspace in the
-  // input feature dimension to create a diagonal predicate.
-  return xla::Eq(broadcasted_expanded_feature_iota, input_feature_iota,
-                 {expanded_filter_shape.dims() - 2});
-}
-
-// Reshapes a filter of shape [H, W, ..., M, N] to [H, W, ..., 1, M*N]. Used to
-// build a depthwise convolution.
-xla::XlaOp ReshapeFilterForDepthwiseConvolution(const TensorShape& filter_shape,
-                                                const xla::XlaOp& filter) {
-  int64 input_feature_dim = filter_shape.dims() - 2;
-  int64 output_feature_dim = filter_shape.dims() - 1;
-  int64 depthwise_multiplier = filter_shape.dim_size(output_feature_dim);
-  int64 input_feature = filter_shape.dim_size(input_feature_dim);
-
-  // Create a [H, W, ..., 1, N*M] reshape of the filter.
-  TensorShape implicit_broadcast_filter_shape = filter_shape;
-  implicit_broadcast_filter_shape.set_dim(input_feature_dim, 1);
-  implicit_broadcast_filter_shape.set_dim(output_feature_dim,
-                                          depthwise_multiplier * input_feature);
-  return xla::Reshape(filter, implicit_broadcast_filter_shape.dim_sizes());
-}
-
-// Reduces the results of the convolution with an expanded filter to the
-// non-expanded filter.
-xla::XlaOp ContractFilterForDepthwiseBackprop(XlaOpKernelContext* ctx,
-                                              const TensorShape& filter_shape,
-                                              DataType dtype,
-                                              const xla::XlaOp& filter_backprop,
-                                              xla::XlaBuilder* builder) {
-  auto masked_expanded_filter = xla::Select(
-      CreateExpandedFilterMask(filter_shape, builder), filter_backprop,
-      CreateExpandedZero(filter_shape, dtype, builder));
-  return xla::Reshape(
-      // This reduce does not need inputs to be converted with
-      // XlaHelpers::SumAccumulationType() since the ExpandedFilterMask with
-      // ExpandedZero guarantees that only one element is non zero, so there
-      // cannot be accumulated precision error.
-      xla::Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype),
-                  *ctx->GetOrCreateAdd(dtype), {filter_shape.dims() - 2}),
-      filter_shape.dim_sizes());
-}
-
 class ConvOp : public XlaOpKernel {
  public:
   explicit ConvOp(OpKernelConstruction* ctx, int num_spatial_dims,
                   bool depthwise)
-      : XlaOpKernel(ctx),
-        num_spatial_dims_(num_spatial_dims),
-        depthwise_(depthwise) {
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
-
-    string data_format;
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
-    OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
-                errors::InvalidArgument("Invalid data format"));
+      : XlaOpKernel(ctx) {
+    xla::StatusOr<ConvOpAttrs> attrs =
+        ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx);
+    OP_REQUIRES_OK(ctx, attrs.status());
+    attrs_ = attrs.ValueOrDie();
   }
 
-  int num_dims() const { return num_spatial_dims_ + 2; }
-
   void Compile(XlaOpKernelContext* ctx) override {
-    OP_REQUIRES(ctx, strides_.size() == num_dims(),
-                errors::InvalidArgument("Sliding window strides field must "
-                                        "specify ",
-                                        num_dims(), " dimensions"));
-    int batch_dim = GetTensorBatchDimIndex(num_dims(), data_format_);
-    int feature_dim = GetTensorFeatureDimIndex(num_dims(), data_format_);
-    OP_REQUIRES(
-        ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1,
-        errors::Unimplemented("Current implementation does not yet support "
-                              "strides in the batch and depth dimensions."));
-
-    OP_REQUIRES(ctx, dilations_.size() == num_dims(),
-                errors::InvalidArgument("Dilations field must "
-                                        "specify ",
-                                        num_dims(), " dimensions"));
-    OP_REQUIRES(
-        ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1,
-        errors::Unimplemented("Current implementation does not support "
-                              "dilations in the batch and depth dimensions."));
-    for (int i = 0; i < num_spatial_dims_; ++i) {
-      int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
-      OP_REQUIRES(ctx, dilations_[input_dim] >= 1,
-                  errors::Unimplemented("Dilation values must be positive; ", i,
-                                        "th spatial dimension had dilation ",
-                                        dilations_[input_dim]));
-    }
-
-    const TensorShape input_shape = ctx->InputShape(0);
-    // Input filter is of the following dimensions:
-    // [ filter_rows, filter_cols, ..., in_depth, out_depth]
-    const TensorShape filter_shape = ctx->InputShape(1);
-
-    // For 2D convolution, there should be 4 dimensions.
-    OP_REQUIRES(
-        ctx, input_shape.dims() == num_dims(),
-        errors::InvalidArgument("input must be ", num_dims(), "-dimensional",
-                                input_shape.DebugString()));
-    OP_REQUIRES(
-        ctx, filter_shape.dims() == num_dims(),
-        errors::InvalidArgument("filter must be ", num_dims(),
-                                "-dimensional: ", filter_shape.DebugString()));
-
-    // The last two dimension of the filter are the input and output shapes.
-    const int64 in_depth = filter_shape.dim_size(num_spatial_dims_);
-
-    // The 'C' dimension for input is in_depth. It must be the same as
-    // the filter's in_depth.
-    OP_REQUIRES(ctx, in_depth == input_shape.dim_size(feature_dim),
-                errors::InvalidArgument(
-                    "input and filter must have the same depth: ", in_depth,
-                    " vs ", input_shape.dim_size(feature_dim)));
-
-    xla::XlaOp filter = ctx->Input(1);
-    if (depthwise_) {
-      filter = ReshapeFilterForDepthwiseConvolution(filter_shape, filter);
-    }
-
-    xla::ConvolutionDimensionNumbers dims;
-    std::vector<int64> window_strides(num_spatial_dims_);
-    std::vector<int64> lhs_dilation(num_spatial_dims_, 1);
-    std::vector<int64> rhs_dilation(num_spatial_dims_);
-    std::vector<std::pair<int64, int64>> padding(num_spatial_dims_);
-
-    dims.set_input_batch_dimension(batch_dim);
-    dims.set_output_batch_dimension(batch_dim);
-    dims.set_input_feature_dimension(feature_dim);
-    dims.set_output_feature_dimension(feature_dim);
-    dims.set_kernel_input_feature_dimension(num_spatial_dims_);
-    dims.set_kernel_output_feature_dimension(num_spatial_dims_ + 1);
-
-    for (int i = 0; i < num_spatial_dims_; ++i) {
-      const int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
-      dims.add_input_spatial_dimensions(dim);
-      dims.add_kernel_spatial_dimensions(i);
-      dims.add_output_spatial_dimensions(dim);
-      window_strides[i] = strides_.at(dim);
-      rhs_dilation[i] = dilations_.at(dim);
-
-      int64 unused_output_size;
-      OP_REQUIRES_OK(
-          ctx, GetWindowedOutputSizeVerboseV2(
-                   input_shape.dim_size(dim), filter_shape.dim_size(i),
-                   rhs_dilation[i], window_strides[i], padding_,
-                   &unused_output_size, &padding[i].first, &padding[i].second));
-    }
-
-    xla::XlaOp conv = xla::ConvGeneralDilated(
-        ctx->Input(0), filter, window_strides, padding, lhs_dilation,
-        rhs_dilation, dims,
-        /*feature_group_count=*/depthwise_ ? in_depth : 1);
-    ctx->SetOutput(0, conv);
+    xla::StatusOr<xla::XlaOp> conv = MakeXlaForwardConvOp(
+        ctx->op_kernel().type_string(), ctx->Input(0), ctx->Input(1), attrs_);
+    OP_REQUIRES_OK(ctx, conv.status());
+    ctx->SetOutput(0, conv.ValueOrDie());
   }
 
  protected:
-  const int num_spatial_dims_;
-  const bool depthwise_;
-  std::vector<int32> dilations_;
-  std::vector<int32> strides_;
-  Padding padding_;
-  TensorFormat data_format_ = FORMAT_NHWC;
+  ConvOpAttrs attrs_;
 
  private:
   TF_DISALLOW_COPY_AND_ASSIGN(ConvOp);
@@ -308,124 +91,28 @@
  public:
   explicit ConvBackpropInputOp(OpKernelConstruction* ctx, int num_spatial_dims,
                                bool depthwise)
-      : XlaOpKernel(ctx),
-        num_spatial_dims_(num_spatial_dims),
-        depthwise_(depthwise) {
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
-    string data_format;
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
-    OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
-                errors::InvalidArgument("Invalid data format"));
+      : XlaOpKernel(ctx) {
+    xla::StatusOr<ConvOpAttrs> attrs =
+        ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx);
+    OP_REQUIRES_OK(ctx, attrs.status());
+    attrs_ = attrs.ValueOrDie();
   }
 
-  int num_dims() const { return num_spatial_dims_ + 2; }
-
   void Compile(XlaOpKernelContext* ctx) override {
-    OP_REQUIRES(ctx, strides_.size() == num_dims(),
-                errors::InvalidArgument("Sliding window strides field must "
-                                        "specify ",
-                                        num_dims(), " dimensions"));
-    int batch_dim = GetTensorBatchDimIndex(num_dims(), data_format_);
-    int feature_dim = GetTensorFeatureDimIndex(num_dims(), data_format_);
-    OP_REQUIRES(
-        ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1,
-        errors::Unimplemented("Current implementation does not yet support "
-                              "strides in the batch and depth dimensions."));
+    TensorShape input_tensor_shape;
+    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_tensor_shape));
+    xla::Shape input_shape =
+        TensorShapeToXLAShape(ctx->input_xla_type(1), input_tensor_shape);
 
-    OP_REQUIRES(ctx, dilations_.size() == num_dims(),
-                errors::InvalidArgument("Dilations field must "
-                                        "specify ",
-                                        num_dims(), " dimensions"));
-    OP_REQUIRES(
-        ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1,
-        errors::Unimplemented("Current implementation does not support "
-                              "dilations in the batch and depth dimensions."));
-    for (int i = 0; i < num_spatial_dims_; ++i) {
-      int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
-      OP_REQUIRES(ctx, dilations_[input_dim] >= 1,
-                  errors::Unimplemented("Dilation values must be positive; ", i,
-                                        "th spatial dimension had dilation ",
-                                        dilations_[input_dim]));
-    }
-
-    TensorShape input_shape;
-    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape));
-
-    const TensorShape filter_shape = ctx->InputShape(1);
-    const TensorShape out_backprop_shape = ctx->InputShape(2);
-
-    const TensorShape expanded_filter_shape =
-        depthwise_ ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
-                   : filter_shape;
-    // Reuse dimension computation logic from conv_grad_ops.cc.
-    ConvBackpropDimensions dims;
-    OP_REQUIRES_OK(ctx,
-                   ConvBackpropComputeDimensionsV2(
-                       type_string(), num_spatial_dims_, input_shape,
-                       expanded_filter_shape, out_backprop_shape, dilations_,
-                       strides_, padding_, data_format_, &dims));
-
-    auto filter = ctx->Input(1);
-    auto out_backprop = ctx->Input(2);
-
-    // The input gradients are computed by a convolution of the output
-    // gradients and the filter, with some appropriate padding. See the
-    // comment at the top of conv_grad_ops.h for details.
-
-    xla::ConvolutionDimensionNumbers dnums;
-    dnums.set_input_batch_dimension(batch_dim);
-    dnums.set_output_batch_dimension(batch_dim);
-    dnums.set_input_feature_dimension(feature_dim);
-    dnums.set_output_feature_dimension(feature_dim);
-
-    // TF filter shape is [ H, W, ..., inC, outC ]
-    // Transpose the input and output features for computing the gradient.
-    dnums.set_kernel_input_feature_dimension(num_spatial_dims_ + 1);
-    dnums.set_kernel_output_feature_dimension(num_spatial_dims_);
-
-    std::vector<int64> kernel_spatial_dims(num_spatial_dims_);
-    std::vector<std::pair<int64, int64>> padding(num_spatial_dims_);
-    std::vector<int64> lhs_dilation(num_spatial_dims_);
-    std::vector<int64> rhs_dilation(num_spatial_dims_);
-    std::vector<int64> ones(num_spatial_dims_, 1);
-    for (int i = 0; i < num_spatial_dims_; ++i) {
-      int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
-      dnums.add_input_spatial_dimensions(dim);
-      dnums.add_kernel_spatial_dimensions(i);
-      dnums.add_output_spatial_dimensions(dim);
-
-      kernel_spatial_dims[i] = i;
-      padding[i] = {dims.spatial_dims[i].pad_before,
-                    dims.spatial_dims[i].pad_after};
-      lhs_dilation[i] = dims.spatial_dims[i].stride;
-      rhs_dilation[i] = dilations_[dim];
-    }
-
-    // Mirror the filter in the spatial dimensions.
-    xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims);
-
-    // activation gradients
-    //   = gradients (with padding and dilation) <conv> mirrored_weights
-    xla::XlaOp in_backprop = xla::ConvGeneralDilated(
-        out_backprop, mirrored_weights, /*window_strides=*/ones, padding,
-        lhs_dilation, rhs_dilation, dnums,
-        /*feature_group_count=*/
-        depthwise_ ? out_backprop_shape.dim_size(feature_dim) /
-                         filter_shape.dim_size(num_spatial_dims_ + 1)
-                   : 1);
-
-    ctx->SetOutput(0, in_backprop);
+    xla::StatusOr<xla::XlaOp> in_backprop =
+        MakeXlaBackpropInputConvOp(ctx->op_kernel().type_string(), input_shape,
+                                   ctx->Input(1), ctx->Input(2), attrs_);
+    OP_REQUIRES_OK(ctx, in_backprop.status());
+    ctx->SetOutput(0, in_backprop.ValueOrDie());
   }
 
  protected:
-  const int num_spatial_dims_;
-  const bool depthwise_;
-  std::vector<int32> dilations_;
-  std::vector<int32> strides_;
-  Padding padding_;
-  TensorFormat data_format_ = FORMAT_NHWC;
+  ConvOpAttrs attrs_;
 
  private:
   TF_DISALLOW_COPY_AND_ASSIGN(ConvBackpropInputOp);
@@ -462,172 +149,28 @@
  public:
   explicit ConvBackpropFilterOp(OpKernelConstruction* ctx, int num_spatial_dims,
                                 bool depthwise)
-      : XlaOpKernel(ctx),
-        num_spatial_dims_(num_spatial_dims),
-        depthwise_(depthwise) {
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
-    string data_format;
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
-    OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
-                errors::InvalidArgument("Invalid data format"));
+      : XlaOpKernel(ctx) {
+    xla::StatusOr<ConvOpAttrs> attrs =
+        ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx);
+    OP_REQUIRES_OK(ctx, attrs.status());
+    attrs_ = attrs.ValueOrDie();
   }
 
-  int num_dims() const { return num_spatial_dims_ + 2; }
-
   void Compile(XlaOpKernelContext* ctx) override {
-    const int n_dim = GetTensorBatchDimIndex(num_dims(), data_format_);
-    const int c_dim = GetTensorFeatureDimIndex(num_dims(), data_format_);
+    TensorShape filter_tensor_shape;
+    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_tensor_shape));
+    xla::Shape filter_shape =
+        TensorShapeToXLAShape(ctx->input_xla_type(0), filter_tensor_shape);
 
-    OP_REQUIRES(
-        ctx, (strides_[n_dim] == 1 && strides_[c_dim] == 1),
-        errors::InvalidArgument("Current implementation does not yet support "
-                                "strides in the batch and depth dimensions."));
-
-    OP_REQUIRES(ctx, dilations_.size() == num_dims(),
-                errors::InvalidArgument("Dilations field must "
-                                        "specify ",
-                                        num_dims(), " dimensions"));
-    OP_REQUIRES(
-        ctx, dilations_[n_dim] == 1 && dilations_[c_dim] == 1,
-        errors::Unimplemented("Current implementation does not support "
-                              "dilations in the batch and depth dimensions."));
-    for (int i = 0; i < num_spatial_dims_; ++i) {
-      int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
-      OP_REQUIRES(ctx, dilations_[input_dim] >= 1,
-                  errors::Unimplemented("Dilation values must be positive; ", i,
-                                        "th spatial dimension had dilation ",
-                                        dilations_[input_dim]));
-    }
-
-    const TensorShape activations_shape = ctx->InputShape(0);
-    TensorShape filter_shape;
-    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_shape));
-    const TensorShape out_backprop_shape = ctx->InputShape(2);
-
-    const TensorShape expanded_filter_shape =
-        depthwise_ ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
-                   : filter_shape;
-
-    // Reuse dimension computation logic from conv_grad_ops.cc.
-    ConvBackpropDimensions dims;
-    OP_REQUIRES_OK(ctx,
-                   ConvBackpropComputeDimensionsV2(
-                       type_string(), num_spatial_dims_, activations_shape,
-                       expanded_filter_shape, out_backprop_shape, dilations_,
-                       strides_, padding_, data_format_, &dims));
-
-    xla::XlaBuilder* b = ctx->builder();
-    xla::XlaOp activations = ctx->Input(0);
-    xla::XlaOp gradients = ctx->Input(2);
-
-    // The filter gradients are computed by a convolution of the input
-    // activations and the output gradients, with some appropriate padding.
-    // See the comment at the top of conv_grad_ops.h for details.
-
-    xla::ConvolutionDimensionNumbers dnums;
-
-    // The activations (inputs) form the LHS of the convolution.
-    // Activations have shape: [batch, in_rows, in_cols, ..., in_depth]
-    // For the gradient computation, we flip the roles of the batch and
-    // feature dimensions.
-    // Each spatial entry has size in_depth * batch
-
-    // Swap n_dim and c_dim in the activations.
-    dnums.set_input_batch_dimension(c_dim);
-    dnums.set_input_feature_dimension(n_dim);
-
-    // The gradients become the RHS of the convolution.
-    // The gradients have shape [batch, out_rows, out_cols, ..., out_depth]
-    // where the batch becomes the input feature for the convolution.
-    dnums.set_kernel_input_feature_dimension(n_dim);
-    dnums.set_kernel_output_feature_dimension(c_dim);
-
-    std::vector<std::pair<int64, int64>> padding(num_spatial_dims_);
-    std::vector<int64> rhs_dilation(num_spatial_dims_);
-    std::vector<int64> window_strides(num_spatial_dims_);
-    std::vector<int64> ones(num_spatial_dims_, 1);
-
-    // Tensorflow filter shape is [ H, W, ..., inC, outC ].
-    for (int i = 0; i < num_spatial_dims_; ++i) {
-      dnums.add_output_spatial_dimensions(i);
-    }
-    dnums.set_output_batch_dimension(num_spatial_dims_);
-    dnums.set_output_feature_dimension(num_spatial_dims_ + 1);
-
-    for (int i = 0; i < num_spatial_dims_; ++i) {
-      int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
-      dnums.add_input_spatial_dimensions(dim);
-      dnums.add_kernel_spatial_dimensions(dim);
-
-      // We will also need to pad the input with zeros such that after the
-      // convolution, we get the right size for the filter.
-      // The padded_in_rows should be such that when we convolve this with the
-      // expanded_out_rows as a filter, we should get filter_rows back.
-      //
-      const int64 padded_in_size =
-          dims.spatial_dims[i].expanded_output_size +
-          (dims.spatial_dims[i].filter_size - 1) * dilations_[dim];
-
-      // However it can be smaller than input_rows: in this
-      // case it means some of the inputs are not used.
-      //
-      // An example is to have input_cols = 3, filter_cols = 2 and stride = 2:
-      //
-      // INPUT =  [ A  B  C ]
-      //
-      // FILTER = [ x y ]
-      //
-      // and the output will only have one column: a = A * x + B * y
-      //
-      // and input "C" is not used at all.
-      //
-      // We apply negative padding in this case.
-      const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size;
-
-      // + For the VALID padding, we don't pad anything on the top/left side
-      //   and pad the bottom/right side with the remaining space.
-      // + For the SAME padding, we pad top/left side the same as bottom/right
-      //   side.
-      //
-      // In addition, if the padded input size is smaller than the input size,
-      // we need to ignore some training elements of the input. We do this by
-      // applying negative padding on the right/bottom.
-      const int64 pad_before =
-          padding_ == Padding::SAME ? std::max<int64>(pad_total / 2, 0) : 0;
-
-      padding[i] = {pad_before, pad_total - pad_before};
-      rhs_dilation[i] = dims.spatial_dims[i].stride;
-      window_strides[i] = dilations_[dim];
-    }
-
-    // Besides padding the input, we will also expand output_rows to
-    //    expanded_out_rows = (output_rows - 1) * stride + 1
-    // with zeros in between:
-    //
-    //      a . . . b . . . c . . . d . . . e
-    //
-    // This is done by specifying the window dilation factors in the
-    // convolution HLO below.
-    auto filter_backprop =
-        xla::ConvGeneralDilated(activations, gradients, window_strides, padding,
-                                /*lhs_dilation=*/ones, rhs_dilation, dnums);
-
-    if (depthwise_) {
-      filter_backprop = ContractFilterForDepthwiseBackprop(
-          ctx, filter_shape, ctx->input_type(0), filter_backprop, b);
-    }
-    ctx->SetOutput(0, filter_backprop);
+    xla::StatusOr<xla::XlaOp> filter_backprop = MakeXlaBackpropFilterConvOp(
+        ctx->op_kernel().type_string(), ctx->Input(0), filter_shape,
+        ctx->Input(2), attrs_);
+    OP_REQUIRES_OK(ctx, filter_backprop.status());
+    ctx->SetOutput(0, filter_backprop.ValueOrDie());
   }
 
  protected:
-  const int num_spatial_dims_;
-  const bool depthwise_;
-  std::vector<int32> dilations_;
-  std::vector<int32> strides_;
-  Padding padding_;
-  TensorFormat data_format_ = FORMAT_NHWC;
+  ConvOpAttrs attrs_;
 
  private:
   TF_DISALLOW_COPY_AND_ASSIGN(ConvBackpropFilterOp);
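The three convolution kernels above now parse their attributes once into a ConvOpAttrs and hand the dimension-number, window, and depthwise handling to the shared helpers MakeXlaForwardConvOp, MakeXlaBackpropInputConvOp, and MakeXlaBackpropFilterConvOp. A minimal sketch of the pattern this enables for other kernels follows; the helper and attribute names are from this change, while the surrounding kernel is illustrative only. As in the removed code, a depthwise convolution still lowers to xla::ConvGeneralDilated with feature_group_count equal to the input depth, i.e. one feature group per input channel.

    // Sketch: a 2-D convolution kernel reusing the shared helpers.
    class ExampleConv2DOp : public XlaOpKernel {
     public:
      explicit ExampleConv2DOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
        // Strides, dilations, padding and data_format are parsed and validated once.
        xla::StatusOr<ConvOpAttrs> attrs =
            ConvOpAttrs::Create(/*num_spatial_dims=*/2, /*depthwise=*/false, ctx);
        OP_REQUIRES_OK(ctx, attrs.status());
        attrs_ = attrs.ValueOrDie();
      }

      void Compile(XlaOpKernelContext* ctx) override {
        // Dimension numbers, padding and windowing are computed in the helper.
        xla::StatusOr<xla::XlaOp> conv = MakeXlaForwardConvOp(
            ctx->op_kernel().type_string(), ctx->Input(0), ctx->Input(1), attrs_);
        OP_REQUIRES_OK(ctx, conv.status());
        ctx->SetOutput(0, conv.ValueOrDie());
      }

     private:
      ConvOpAttrs attrs_;
    };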
diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
index d9a0257..7b2bb4a 100644
--- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
@@ -18,6 +18,7 @@
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/client/lib/numeric.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/core/framework/kernel_def_builder.h"
@@ -132,14 +133,14 @@
 // If the 2D kernel would be very large, the 1D kernel can be applied once in
 // each dimension due to the symmetry of the kernel along all axis to reduce the
 // computational intensity.
-std::vector<float> Make1DKernel(int64 n) {
+xla::XlaOp Make1DKernel(xla::XlaBuilder* builder, int64 n) {
   std::vector<float> kernel(n * 2 - 1);
   for (int64 i = 0; i < n; ++i) {
     float v = (i + 1.0f) / n;
     kernel[i] = v;
     kernel[n * 2 - 2 - i] = v;
   }
-  return kernel;
+  return xla::ConstantR1<float>(builder, kernel);
 }
 
 // Kernels with more than 16 spatial elements are considered intense and the
@@ -149,41 +150,26 @@
 xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder,
                                     absl::Span<const int64> kernel_size,
                                     int64 channels) {
-  xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels);
+  auto depthwise_kernel = xla::Broadcast(
+      xla::Zero(builder, xla::F32),
+      {(2 * kernel_size[0] - 1), (2 * kernel_size[1] - 1), channels, 1});
 
-  auto diag = xla::ConvertElementType(
-      xla::Eq(xla::Broadcast(channels_iota, {2 * kernel_size[0] - 1,
-                                             2 * kernel_size[1] - 1, channels}),
-              channels_iota, /*broadcast_dimensions=*/{2}),
-      xla::PrimitiveType::F32);
   return xla::Mul(
-      xla::Mul(diag,
-               xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[1])),
+      xla::Add(depthwise_kernel, Make1DKernel(builder, kernel_size[1]),
                /*broadcast_dimensions=*/{1}),
-      xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[0])),
+      Make1DKernel(builder, kernel_size[0]),
       /*broadcast_dimensions=*/{0});
 }
 
 xla::XlaOp MakeBilinearResizeKernelInDim(xla::XlaBuilder* builder,
                                          absl::Span<const int64> kernel_size,
                                          int64 channels, int64 dim) {
-  xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels);
-
-  auto diag = xla::ConvertElementType(
-      xla::Eq(
-          xla::Broadcast(channels_iota,
-                         {dim == 0 ? (2 * kernel_size[0] - 1) : 1,
-                          dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels}),
-          channels_iota, /*broadcast_dimensions=*/{2}),
-      xla::PrimitiveType::F32);
-  if (dim == 1) {
-    return xla::Mul(
-        diag, xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[1])),
-        /*broadcast_dimensions=*/{1});
-  }
-  return xla::Mul(diag,
-                  xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[0])),
-                  /*broadcast_dimensions=*/{0});
+  auto depthwise_kernel =
+      xla::Broadcast(xla::Zero(builder, xla::F32),
+                     {dim == 0 ? (2 * kernel_size[0] - 1) : 1,
+                      dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels, 1});
+  return xla::Add(depthwise_kernel, Make1DKernel(builder, kernel_size[dim]),
+                  /*broadcast_dimensions=*/{dim});
 }
 
 xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
@@ -206,8 +192,8 @@
   xla::ConvolutionDimensionNumbers dimension_numbers;
   dimension_numbers.set_input_batch_dimension(0);
   dimension_numbers.set_output_batch_dimension(0);
-  dimension_numbers.set_input_feature_dimension(3);
-  dimension_numbers.set_output_feature_dimension(3);
+  dimension_numbers.set_input_feature_dimension(num_spatial_dims + 1);
+  dimension_numbers.set_output_feature_dimension(num_spatial_dims + 1);
   for (int i = 0; i < num_spatial_dims; ++i) {
     dimension_numbers.add_input_spatial_dimensions(1 + i);
     dimension_numbers.add_output_spatial_dimensions(1 + i);
@@ -285,7 +271,8 @@
                                 {{dims.kernel_size[0] - 1, upper_padding[0]},
                                  {dims.kernel_size[1] - 1, upper_padding[1]}},
                                 /*lhs_dilation=*/dims.kernel_size,
-                                /*rhs_dilation=*/{1, 1}, dimension_numbers);
+                                /*rhs_dilation=*/{1, 1}, dimension_numbers,
+                                /*feature_group_count=*/channels);
   } else {
     xla::XlaOp kernel0 =
         MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0);
@@ -294,7 +281,8 @@
         /*padding=*/
         {{dims.kernel_size[0] - 1, upper_padding[0]}, {0, 0}},
         /*lhs_dilation=*/{dims.kernel_size[0], 1},
-        /*rhs_dilation=*/{1, 1}, dimension_numbers);
+        /*rhs_dilation=*/{1, 1}, dimension_numbers,
+        /*feature_group_count=*/channels);
     xla::XlaOp kernel1 =
         MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1);
     output = xla::ConvGeneralDilated(
@@ -302,7 +290,8 @@
         /*padding=*/
         {{0, 0}, {dims.kernel_size[1] - 1, upper_padding[1]}},
         /*lhs_dilation=*/{1, dims.kernel_size[1]},
-        /*rhs_dilation=*/{1, 1}, dimension_numbers);
+        /*rhs_dilation=*/{1, 1}, dimension_numbers,
+        /*feature_group_count=*/channels);
   }
 
   // Add broadcasts to handle expanding from a size == 1 dimension to a
@@ -331,15 +320,15 @@
   xla::ConvolutionDimensionNumbers dimension_numbers;
   dimension_numbers.set_input_batch_dimension(0);
   dimension_numbers.set_output_batch_dimension(0);
-  dimension_numbers.set_input_feature_dimension(3);
-  dimension_numbers.set_output_feature_dimension(3);
+  dimension_numbers.set_input_feature_dimension(num_spatial_dims + 1);
+  dimension_numbers.set_output_feature_dimension(num_spatial_dims + 1);
   for (int i = 0; i < num_spatial_dims; ++i) {
-    dimension_numbers.add_input_spatial_dimensions(1 + i);
-    dimension_numbers.add_output_spatial_dimensions(1 + i);
+    dimension_numbers.add_input_spatial_dimensions(i + 1);
+    dimension_numbers.add_output_spatial_dimensions(i + 1);
     dimension_numbers.add_kernel_spatial_dimensions(i);
   }
-  dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims);
-  dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1);
+  dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims + 1);
+  dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims);
   xla::XlaOp output;
   if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) {
     xla::XlaOp kernel =
@@ -362,7 +351,8 @@
         {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1},
          {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
         /*lhs_dilation=*/dims.stride,
-        /*rhs_dilation=*/{1, 1}, dimension_numbers);
+        /*rhs_dilation=*/{1, 1}, dimension_numbers,
+        /*feature_group_count=*/channels);
   } else {
     xla::XlaOp kernel0 =
         MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0);
@@ -388,14 +378,16 @@
         /*padding=*/
         {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}},
         /*lhs_dilation=*/{dims.stride[0], 1},
-        /*rhs_dilation=*/{1, 1}, dimension_numbers);
+        /*rhs_dilation=*/{1, 1}, dimension_numbers,
+        /*feature_group_count=*/channels);
 
     output = xla::ConvGeneralDilated(
         output, kernel1, /*window_strides=*/{1, dims.kernel_size[1]},
         /*padding=*/
         {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
         /*lhs_dilation=*/{1, dims.stride[1]},
-        /*rhs_dilation=*/{1, 1}, dimension_numbers);
+        /*rhs_dilation=*/{1, 1}, dimension_numbers,
+        /*feature_group_count=*/channels);
   }
 
   // If in_size[i] > 1 and grad_size[i] == 1, pad the output in dimension i.
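Make1DKernel still produces the symmetric triangle kernel (i + 1) / n mirrored around its center; it is now emitted directly as an xla::ConstantR1, added onto a zero [2n-1, 2n-1, channels, 1] depthwise filter, and the resize convolutions pass feature_group_count=channels instead of building the old per-channel diagonal kernel with Iota/Eq. A standalone sketch of the 1-D kernel values, independent of XLA:

    #include <cstdio>
    #include <vector>

    // Mirrors Make1DKernel: a triangle filter of length 2*n - 1.
    std::vector<float> Make1DKernelValues(int n) {
      std::vector<float> kernel(n * 2 - 1);
      for (int i = 0; i < n; ++i) {
        float v = (i + 1.0f) / n;
        kernel[i] = v;              // rising edge
        kernel[n * 2 - 2 - i] = v;  // mirrored falling edge
      }
      return kernel;
    }

    int main() {
      // For n = 3 this prints: 0.33 0.67 1.00 0.67 0.33
      for (float v : Make1DKernelValues(3)) std::printf("%.2f ", v);
      std::printf("\n");
      return 0;
    }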
diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
index 2e0a69b..c8a0f31 100644
--- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
@@ -44,7 +44,7 @@
   DataType out_dtype_;
 };
 
-REGISTER_XLA_OP(Name("Shape").CompilationOnly(), ShapeOp);
+REGISTER_XLA_OP(Name("Shape").CompilationOnly().IsMetadataOp(), ShapeOp);
 
 class ShapeNOp : public XlaOpKernel {
  public:
@@ -66,7 +66,7 @@
  private:
   DataType out_dtype_;
 };
-REGISTER_XLA_OP(Name("ShapeN").CompilationOnly(), ShapeNOp);
+REGISTER_XLA_OP(Name("ShapeN").CompilationOnly().IsMetadataOp(), ShapeNOp);
 
 class RankOp : public XlaOpKernel {
  public:
@@ -82,7 +82,7 @@
   }
 };
 
-REGISTER_XLA_OP(Name("Rank").CompilationOnly(), RankOp);
+REGISTER_XLA_OP(Name("Rank").CompilationOnly().IsMetadataOp(), RankOp);
 
 class SizeOp : public XlaOpKernel {
  public:
@@ -101,7 +101,7 @@
   }
 };
 
-REGISTER_XLA_OP(Name("Size").CompilationOnly(), SizeOp);
+REGISTER_XLA_OP(Name("Size").CompilationOnly().IsMetadataOp(), SizeOp);
 
 class ExpandDimsOp : public XlaOpKernel {
  public:
diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc
index 64f2d78..5400e88 100644
--- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc
+++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc
@@ -100,16 +100,6 @@
     precision_proto.add_operand_precision(precision);
     precision_proto.add_operand_precision(precision);
 
-    // If there are no batch dimensions, use a regular Dot.
-    // TODO(b/69062148) Remove this code when Dot emitters can be passed
-    // dimensions to transpose directly (i.e. without requiring a Transpose
-    // HLO).
-    if (batch_dimension_numbers.empty()) {
-      auto lhs = transpose_x ? xla::Transpose(x, {1, 0}) : x;
-      auto rhs = transpose_y ? xla::Transpose(y, {1, 0}) : y;
-      return xla::Dot(lhs, rhs, &precision_proto);
-    }
-
     xla::DotDimensionNumbers dot_dnums;
     dot_dnums.add_lhs_contracting_dimensions(x_inner_dim);
     dot_dnums.add_rhs_contracting_dimensions(y_inner_dim);
diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
index 0236350..733eeed 100644
--- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc
+++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
@@ -121,8 +121,8 @@
 DynamicSlice extracts a sub-array from the input array at dynamic
 start_indices. The size of the slice in each dimension is passed in
 size_indices, which specify the end point of exclusive slice intervals in each
-dimension -- [start, start + size). The shape of start_indices must be rank ==
-1, with dimension size equal to the rank of operand.
+dimension -- [start, start + size). The shape of start_indices must have rank 1,
+with dimension size equal to the rank of operand.
 
 input: A `Tensor` of type T.
 
@@ -131,7 +131,8 @@
 
 start_indices: List of N integers containing the slice size for each
   dimension. Each value must be strictly greater than zero, and start + size
-  must be less
+  must be less than or equal to the size of the dimension to avoid
+  implementation-defined behavior.
 )doc");
 
 REGISTER_OP("XlaDynamicUpdateSlice")
diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc
index 9d19922..b589512 100644
--- a/tensorflow/compiler/tf2xla/shape_util.cc
+++ b/tensorflow/compiler/tf2xla/shape_util.cc
@@ -41,6 +41,14 @@
 // Convert a TensorShape into the equivalent XLA Shape proto.
 Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape,
                              xla::Shape* shape) {
+  xla::PrimitiveType type;
+  TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type));
+  *shape = TensorShapeToXLAShape(type, tensor_shape);
+  return Status::OK();
+}
+
+xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
+                                 const TensorShape& tensor_shape) {
   int rank = tensor_shape.dims();
   std::vector<int64> dimensions(rank);
   std::vector<int64> layout(rank);
@@ -50,11 +58,7 @@
   // XLA uses minor-to-major; Tensorflow uses major-to-minor.
   std::iota(layout.rbegin(), layout.rend(), 0);
 
-  xla::PrimitiveType type;
-  TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type));
-
-  *shape = xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout);
-  return Status::OK();
+  return xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout);
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/shape_util.h b/tensorflow/compiler/tf2xla/shape_util.h
index 58240b9..f7e34a5 100644
--- a/tensorflow/compiler/tf2xla/shape_util.h
+++ b/tensorflow/compiler/tf2xla/shape_util.h
@@ -35,6 +35,11 @@
 Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape,
                              xla::Shape* shape);
 
+// Converts a TensorShape into the equivalent XLA Shape proto, taking an
+// xla::PrimitiveType to specify the element type.  This never fails.
+xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
+                                 const TensorShape& tensor_shape);
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_
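The new overload is convenient when the element type is already an xla::PrimitiveType (for example from XlaOpKernelContext::input_xla_type, as the convolution kernels above use it), since it skips the DataType conversion and cannot fail. A small usage sketch, assuming a TensorFlow build; the shape values are illustrative:

    #include "tensorflow/compiler/tf2xla/shape_util.h"
    #include "tensorflow/core/framework/tensor_shape.h"

    xla::Shape ExampleShape() {
      tensorflow::TensorShape tensor_shape({2, 3});
      // The element type is given directly as an xla::PrimitiveType, so there is
      // no DataType lookup and no Status to check.
      // Result: f32[2,3] with XLA's default minor-to-major layout {1,0}.
      return tensorflow::TensorShapeToXLAShape(xla::F32, tensor_shape);
    }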
diff --git a/tensorflow/compiler/tf2xla/test_util.cc b/tensorflow/compiler/tf2xla/test_util.cc
index 3c6c9a9..f31bfb4 100644
--- a/tensorflow/compiler/tf2xla/test_util.cc
+++ b/tensorflow/compiler/tf2xla/test_util.cc
@@ -40,4 +40,12 @@
   return Status::OK();
 }
 
+std::unordered_map<string, Node*> BuildNodeIndex(const Graph& graph) {
+  std::unordered_map<string, Node*> index;
+  for (Node* node : graph.nodes()) {
+    index[node->name()] = node;
+  }
+  return index;
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/test_util.h b/tensorflow/compiler/tf2xla/test_util.h
index e6e4ae9..350a868 100644
--- a/tensorflow/compiler/tf2xla/test_util.h
+++ b/tensorflow/compiler/tf2xla/test_util.h
@@ -24,8 +24,10 @@
 
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/graph_def_util.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/util/equal_graph_def.h"
 
 namespace tensorflow {
 
@@ -42,6 +44,20 @@
                                   const FunctionLibraryDefinition& library,
                                   InstantiationResultForTest* result);
 
+// Builds a map from node name to Node* for `graph`.
+std::unordered_map<string, Node*> BuildNodeIndex(const Graph& graph);
+
 }  // namespace tensorflow
 
+// Variant of TF_EXPECT_GRAPH_EQ that also compares internal attributes for
+// equality.
+#define TF_EXPECT_GRAPH_EQ_INTERNAL(expected, actual)               \
+  do {                                                              \
+    string diff;                                                    \
+    EqualGraphDefOptions eq_options;                                \
+    eq_options.ignore_internal_attrs = false;                       \
+    EXPECT_TRUE(EqualGraphDef(actual, expected, &diff, eq_options)) \
+        << diff << "\nActual: " << SummarizeGraphDef(actual);       \
+  } while (false)
+
 #endif  // TENSORFLOW_COMPILER_TF2XLA_TEST_UTIL_H_
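A rough sketch of how a tf2xla test might combine the two additions, looking up nodes by name and comparing GraphDefs including internal ("_"-prefixed) attributes. The Graph objects and node names here are assumed to come from the test itself:

    // Inside a TEST(...) body, with tensorflow::Graph objects `expected` and `actual`.
    std::unordered_map<string, Node*> index = BuildNodeIndex(actual);
    Node* retval = index.at("retval");  // throws if the node is missing
    EXPECT_EQ("_Retval", retval->type_string());

    GraphDef expected_def, actual_def;
    expected.ToGraphDef(&expected_def);
    actual.ToGraphDef(&actual_def);
    // Unlike TF_EXPECT_GRAPH_EQ, this also compares internal attributes.
    TF_EXPECT_GRAPH_EQ_INTERNAL(expected_def, actual_def);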
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 105f3b6..739e477 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -325,8 +325,7 @@
       step_container->name(), XlaContext::kXlaContextResourceName,
       xla_context));
 
-  GraphCompiler graph_compiler(xla_context, device, graph.get(), flib,
-                               step_container.get());
+  GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get());
   TF_RETURN_IF_ERROR(graph_compiler.Compile());
   // Explicitly clean up the step container, to capture the cleanup status.
   step_container.reset();
diff --git a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc
index 23d04d4..bc44301 100644
--- a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc
+++ b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc
@@ -20,21 +20,6 @@
 namespace tensorflow {
 
 bool CpuOpFilter(KernelDef* kdef) {
-  // TODO(b/34339814): implement inverse erf for double types and remove this
-  // workaround.
-  if (kdef->op() == "RandomStandardNormal") {
-    kdef->clear_constraint();
-    // Change the type constraint to permit only DTD_FLOAT.
-    KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint();
-    attr_constraint->set_name("dtype");
-    attr_constraint->mutable_allowed_values()->mutable_list()->add_type(
-        DT_FLOAT);
-    return true;
-  }
-  // TODO(b/26783907): The CPU backend currently does not implement sort.
-  if (kdef->op() == "XlaSort" || kdef->op() == "TopKV2") {
-    return false;
-  }
   if (kdef->op() == "Const") {
     AddDtypeToKernalDefConstraint("dtype", DT_STRING, kdef);
   }
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc
index b0eeee3..91d4812 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc
@@ -90,6 +90,11 @@
                  << " have incompatible compile time constant inputs.";
     return false;
   }
+  if (x.is_metadata_op != y.is_metadata_op) {
+    LOG(WARNING) << "Registrations of " << x.name
+                 << " have incompatible values for is_metadata_op.";
+    return false;
+  }
   return true;
 }
 
@@ -350,6 +355,20 @@
   return &it->second.front()->compile_time_constant_inputs;
 }
 
+/*static*/ bool XlaOpRegistry::IsMetadataOp(const string& op) {
+  XlaOpRegistry& registry = Instance();
+  mutex_lock lock(registry.mutex_);
+  auto it = registry.ops_.find(op);
+  if (it == registry.ops_.end() || it->second.empty()) {
+    return false;
+  }
+
+  // The test in IsCompatible ensures that if there are multiple matching
+  // registrations for this op name, they all have the same value of
+  // is_metadata_op, so only the first match is returned.
+  return it->second.front()->is_metadata_op;
+}
+
 std::vector<string> XlaOpRegistry::BackendNames() {
   std::vector<string> names;
   XlaOpRegistry& registry = Instance();
@@ -432,6 +451,11 @@
   return *this;
 }
 
+XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::IsMetadataOp() {
+  registration_->is_metadata_op = true;
+  return *this;
+}
+
 std::unique_ptr<XlaOpRegistry::OpRegistration> XlaOpRegistrationBuilder::Build(
     XlaOpRegistry::Factory factory) {
   registration_->factory = factory;
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h
index 74a4885..4b2c2ba 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.h
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.h
@@ -47,17 +47,18 @@
 
 constexpr std::array<DataType, 4> kFloatTypes = {
     {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}};
-constexpr std::array<DataType, 9> kNumericTypes = {
-    {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
-     DT_COMPLEX64, DT_BFLOAT16}};
+constexpr std::array<DataType, 11> kNumericTypes = {
+    {DT_UINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_INT32, DT_INT64, DT_HALF,
+     DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BFLOAT16}};
 
-constexpr std::array<DataType, 9> kCpuAllTypes = {
-    {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
-     DT_COMPLEX64, DT_BOOL}};
+constexpr std::array<DataType, 14> kCpuAllTypes = {
+    {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32,
+     DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
 
-constexpr std::array<DataType, 10> kGpuAllTypes = {
-    {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
-     DT_COMPLEX64, DT_BOOL, DT_BFLOAT16}};
+constexpr std::array<DataType, 15> kGpuAllTypes = {
+    {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32,
+     DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL,
+     DT_BFLOAT16}};
 
 // Class that manages registrations of operators and devices for the XLA JIT.
 // Not thread-safe.
@@ -136,6 +137,10 @@
   static const std::unordered_set<string>* CompileTimeConstantInputs(
       const string& op);
 
+  // Returns true if `op` is a "metadata" op, one that only looks at the shapes
+  // of its operands and not their values.
+  static bool IsMetadataOp(const string& op);
+
  private:
   friend class XlaBackendRegistrar;
   friend class XlaOpRegistrar;
@@ -192,6 +197,10 @@
     // Names of arguments that must be compile-time constants.
     std::unordered_set<string> compile_time_constant_inputs;
 
+    // True if this is a "metadata" op, one that only looks at the shapes of its
+    // operands and not their values.
+    bool is_metadata_op = false;
+
     // Factory used to build OpKernels that perform symbolic execution.
     Factory factory;
   };
@@ -256,6 +265,10 @@
   // Mark 'input_name' as an argument whose value must be known at compile-time.
   XlaOpRegistrationBuilder& CompileTimeConstInput(absl::string_view input_name);
 
+  // Mark this op as a "metadata" op, one that only looks at the shapes of its
+  // operands and not their values.
+  XlaOpRegistrationBuilder& IsMetadataOp();
+
   std::unique_ptr<XlaOpRegistry::OpRegistration> Build(
       XlaOpRegistry::Factory factory);
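Registration and query of the new bit go together; the example below mirrors the Shape registration changed earlier in this diff, and the branch body is only a placeholder:

    // At registration time, mark an op that only inspects operand shapes.
    REGISTER_XLA_OP(Name("Shape").CompilationOnly().IsMetadataOp(), ShapeOp);

    // Later, callers can ask the registry whether an op is shape-only.
    if (XlaOpRegistry::IsMetadataOp("Shape")) {
      // The op's output depends only on operand shapes, not on their values.
    }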
 
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index 76e36f3..cc7390c 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -193,6 +193,7 @@
         ":types",
         ":util",
         "//tensorflow/core:lib",
+        "@com_google_absl//absl/synchronization",
     ],
 )
 
@@ -244,6 +245,7 @@
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:regexp_internal",
+        "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:optional",
diff --git a/tensorflow/compiler/xla/protobuf_util.cc b/tensorflow/compiler/xla/protobuf_util.cc
index 787725e..b507a2e 100644
--- a/tensorflow/compiler/xla/protobuf_util.cc
+++ b/tensorflow/compiler/xla/protobuf_util.cc
@@ -20,6 +20,7 @@
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/protobuf.h"
 
 namespace xla {
@@ -49,16 +50,40 @@
   return safe_file_name;
 }
 
+std::pair<tensorflow::mutex*, std::vector<std::function<string(string)>>*>
+GetDirectoryExpanders() {
+  static auto* mutex = new tensorflow::mutex;
+  static auto* singleton = new std::vector<std::function<string(string)>>;
+  return {mutex, singleton};
+}
+
+// Runs all the directory expanders over x and returns the result.
+string Expand(string x) {
+  auto pair = GetDirectoryExpanders();
+  tensorflow::mutex_lock lock(*pair.first);
+  for (const auto& f : *pair.second) {
+    x = f(x);
+  }
+  return x;
+}
+
 }  // namespace
 
 Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message,
                             const string& directory, const string& file_name) {
   tensorflow::Env* env = tensorflow::Env::Default();
-  TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory));
+  string expanded_dir = Expand(directory);
+  TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(expanded_dir));
   string safe_file_name = SanitizeFileName(file_name) + ".pb";
-  const string path = tensorflow::io::JoinPath(directory, safe_file_name);
+  const string path = tensorflow::io::JoinPath(expanded_dir, safe_file_name);
   return tensorflow::WriteBinaryProto(env, path, message);
 }
 
+void RegisterDirectoryExpander(const std::function<string(string)>& expander) {
+  auto pair = GetDirectoryExpanders();
+  tensorflow::mutex_lock lock(*pair.first);
+  pair.second->push_back(expander);
+}
+
 }  // namespace protobuf_util
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/protobuf_util.h b/tensorflow/compiler/xla/protobuf_util.h
index 3667621..f22fc8b 100644
--- a/tensorflow/compiler/xla/protobuf_util.h
+++ b/tensorflow/compiler/xla/protobuf_util.h
@@ -39,6 +39,10 @@
 Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message,
                             const string& directory, const string& file_name);
 
+// Registers a function that may either expand a dirpath or forward the original
+// dirpath along as-is.
+void RegisterDirectoryExpander(const std::function<string(string)>& expander);
+
 }  // namespace protobuf_util
 }  // namespace xla
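RegisterDirectoryExpander lets a process rewrite the dump directory before DumpProtoToDirectory creates it. A hedged sketch; the "~" convention and the expander body are illustrative, only the registration signature comes from this change:

    #include <cstdlib>
    #include <string>

    #include "tensorflow/compiler/xla/protobuf_util.h"

    void InstallHomeDirectoryExpander() {
      xla::protobuf_util::RegisterDirectoryExpander([](std::string path) {
        // Expand a leading "~" to $HOME; otherwise forward the path unchanged.
        if (!path.empty() && path[0] == '~') {
          const char* home = std::getenv("HOME");
          if (home != nullptr) path = std::string(home) + path.substr(1);
        }
        return path;
      });
    }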
 
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc
index 9da5dc0..cd5fd33 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.cc
+++ b/tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -469,9 +469,11 @@
     absl::Span<const int64> window_strides,
     absl::Span<const std::pair<int64, int64>> padding,
     absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
-    const ConvolutionDimensionNumbers& dimension_numbers) {
+    const ConvolutionDimensionNumbers& dimension_numbers,
+    int64 feature_group_count) {
   return xla::ConvGeneralDilated(lhs.op(), rhs.op(), window_strides, padding,
-                                 lhs_dilation, rhs_dilation, dimension_numbers);
+                                 lhs_dilation, rhs_dilation, dimension_numbers,
+                                 feature_group_count);
 }
 
 LocalOp LocalComputationBuilder::ConvertElementType(
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h
index 1d5dfe5..2166bb6 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.h
+++ b/tensorflow/compiler/xla/python/local_computation_builder.h
@@ -248,7 +248,8 @@
       absl::Span<const std::pair<int64, int64> > padding,
       absl::Span<const int64> lhs_dilation,
       absl::Span<const int64> rhs_dilation,
-      const ConvolutionDimensionNumbers& dimension_numbers);
+      const ConvolutionDimensionNumbers& dimension_numbers,
+      int64 feature_group_count);
 
   LocalOp ConvertElementType(const LocalOp& operand,
                              PrimitiveType new_element_type);
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index fa4366f..bb303c5 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -1109,7 +1109,7 @@
       dimension_numbers = GetDotDimensionsFromLists(dimension_numbers)
     return self._client.DotGeneral(lhs, rhs, dimension_numbers)
 
-  def Conv(self, lhs, rhs, window_strides, padding):
+  def Conv(self, lhs, rhs, window_strides, padding, feature_group_count=1):
     """Enqueues a Conv operation onto the computation.
 
     Args:
@@ -1117,6 +1117,7 @@
       rhs: LocalOp for the rank N+2 array of kernel weights.
       window_strides: length-N array-like of integer kernel strides.
       padding: PaddingType representing either 'SAME' or 'VALID' padding.
+      feature_group_count: number of feature groups for grouped convolution.
 
     Returns: a LocalOp representing the Conv operation.
     """
@@ -1125,10 +1126,11 @@
         self.GetShape(rhs).dimensions()[2:], window_strides)
     dimension_numbers = self._GetConvDimensionNumbers(len(window_strides))
     return self._client.ConvGeneralDilated(lhs, rhs, window_strides, pads, (),
-                                           (), dimension_numbers)
+                                           (), dimension_numbers,
+                                           feature_group_count)
 
   def ConvWithGeneralPadding(self, lhs, rhs, window_strides, padding,
-                             lhs_dilation, rhs_dilation):
+                             lhs_dilation, rhs_dilation, feature_group_count=1):
     """Enqueues a ConvWithGeneralPadding operation onto the computation.
 
     Args:
@@ -1138,6 +1140,7 @@
       padding: length-N array-like of pairs of integers of (low, high) padding.
       lhs_dilation: length-N array-like of dilation factors.
       rhs_dilation: length-N array-like of dilation factors.
+      feature_group_count: number of feature groups for grouped convolution.
 
     Returns:
       A ComputationdataHandle representing the added ConvWithGeneralPadding op.
@@ -1145,7 +1148,8 @@
     dimension_numbers = self._GetConvDimensionNumbers(len(window_strides))
     return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding,
                                            lhs_dilation, rhs_dilation,
-                                           dimension_numbers)
+                                           dimension_numbers,
+                                           feature_group_count)
 
   def _GetConvDimensionNumbers(self, num_spatial_dims):
     """Create ConvolutionDimensionNumbers proto for convolutions."""
@@ -1163,7 +1167,8 @@
     return dimension_numbers
 
   def ConvGeneralDilated(self, lhs, rhs, window_strides, padding, lhs_dilation,
-                         rhs_dilation, dimension_numbers):
+                         rhs_dilation, dimension_numbers,
+                         feature_group_count=1):
     """Enqueues a ConvGeneralDilated operation onto the computation.
 
     Args:
@@ -1190,6 +1195,7 @@
         labels appear in the rhs_spec string, so that window_strides[0] is
         matched with the dimension corresponding to the first character
         appearing in rhs_spec that is not 'I' or 'O'.
+      feature_group_count: number of feature groups for grouped convolution.
 
     Returns: a LocalOp representing the ConvGenralDilated operation.
     """
@@ -1215,7 +1221,8 @@
                  key=lambda i: rhs_spec.index(out_spec[i])))
     return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding,
                                            lhs_dilation, rhs_dilation,
-                                           dimension_numbers)
+                                           dimension_numbers,
+                                           feature_group_count)
 
   def Sort(self, operand, dimension=-1):
     """Enqueues a sort operation onto the computation."""
diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py
index fd98e19..82103f0 100644
--- a/tensorflow/compiler/xla/python/xla_client_test.py
+++ b/tensorflow/compiler/xla/python/xla_client_test.py
@@ -661,6 +661,30 @@
                          [40., 50., 0.]]]])
     self._ExecuteAndCompareClose(c, expected=np.transpose(result, (1, 3, 0, 2)))
 
+  def testConvGeneralDilatedGroupedConvolutionF32(self):
+    c = self._NewComputation()
+    a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
+    lhs = a(1, 2, 2, 3)
+    rhs = a(2, 1, 1, 2) * 10
+    strides = [1, 1]
+    pads = [(1, 0), (0, 1)]
+    lhs_dilation = (2, 1)
+    rhs_dilation = (1, 1)
+    dimension_numbers = ("NCHW", "OIHW", "NCHW")
+    feature_group_count = 2
+    c.ConvGeneralDilated(c.Constant(lhs), c.Constant(rhs),
+                         strides, pads, lhs_dilation, rhs_dilation,
+                         dimension_numbers, feature_group_count)
+    result = np.array([[[[0., 0., 0.],
+                         [10., 20., 0.],
+                         [0., 0., 0.],
+                         [40., 50., 0.]],
+                        [[0., 0., 0.],
+                         [330., 380., 160.],
+                         [0., 0., 0.],
+                         [480., 530., 220.]]]])
+    self._ExecuteAndCompareClose(c, expected=result)
+
   def testBooleanNot(self):
     c = self._NewComputation()
     arr = NumpyArrayBool([True, False, True])
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index fb80c78..2bc50c7 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -365,8 +365,11 @@
     hdrs = ["pattern_matcher.h"],
     deps = [
         ":hlo",
+        ":hlo_casting_utils",
+        "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:shape_util",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/utility",
     ],
 )
 
@@ -1166,6 +1169,7 @@
         ":hlo",
         ":hlo_matchers",
         ":hlo_module_group",
+        ":hlo_module_group_metadata",
         ":hlo_parser",
         ":hlo_proto",
         "//tensorflow/compiler/xla:test",
@@ -2557,6 +2561,7 @@
     ],
     deps = [
         ":hlo",
+        ":hlo_module_group",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:types",
@@ -2588,6 +2593,26 @@
     ],
 )
 
+tf_cc_test(
+    name = "hlo_pass_pipeline_test",
+    srcs = ["hlo_pass_pipeline_test.cc"],
+    deps = [
+        ":hlo",
+        ":hlo_parser",
+        ":hlo_pass_pipeline",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla:test_helpers",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
+        "//tensorflow/compiler/xla/tests:test_utils",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+    ],
+)
+
 cc_library(
     name = "hlo_cse",
     srcs = ["hlo_cse.cc"],
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 5458159..75dae7a 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -745,12 +745,25 @@
   }
   const int64 rhs_kept_dim = 1 - rhs_collapsing_dim;
 
-  auto reshape_if_necessary = [&](HloInstruction* hlo) {
-    if (ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) {
+  auto as_type = [&](HloInstruction* hlo, const PrimitiveType element_type) {
+    if (hlo->shape().element_type() == element_type) {
       return hlo;
     }
-    return computation_->AddInstruction(
-        HloInstruction::CreateReshape(dot->shape(), hlo));
+    return computation_->AddInstruction(HloInstruction::CreateConvert(
+        ShapeUtil::ChangeElementType(hlo->shape(), element_type), hlo));
+  };
+
+  auto reshape_if_necessary = [&](HloInstruction* hlo) {
+    hlo = as_type(hlo, dot->shape().element_type());
+    if (!ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) {
+      hlo = computation_->AddInstruction(
+          HloInstruction::CreateReshape(dot->shape(), hlo));
+    }
+    return hlo;
+  };
+
+  auto add_reduce_in_f32 = [&](HloInstruction* hlo, const int64 dim) {
+    return AddReduce(as_type(hlo, F32), dim);
   };
 
   auto broadcast_to_dim = [&](HloInstruction* hlo, const Shape& shape,
@@ -770,7 +783,7 @@
   if (ShapeUtil::Rank(rhs->shape()) == 1 &&
       ShapeUtil::Rank(lhs->shape()) == 1) {
     TF_RETURN_IF_ERROR(
-        ReplaceInstruction(dot, reshape_if_necessary(AddReduce(
+        ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32(
                                     multiply(Flatten(lhs), Flatten(rhs)), 0))));
     return true;
   }
@@ -804,17 +817,17 @@
       (ShapeUtil::Rank(lhs->shape()) == 2 &&
        lhs->shape().dimensions(lhs_kept_dim) == 1)) {
     if (ShapeUtil::Rank(rhs->shape()) == 1) {
-      TF_RETURN_IF_ERROR(ReplaceInstruction(
-          dot,
-          reshape_if_necessary(AddReduce(multiply(Flatten(lhs), rhs), 0))));
+      TF_RETURN_IF_ERROR(
+          ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32(
+                                      multiply(Flatten(lhs), rhs), 0))));
       return true;
     }
     TF_RETURN_IF_ERROR(ReplaceInstruction(
-        dot, reshape_if_necessary(
-                 AddReduce(multiply(broadcast_to_dim(Flatten(lhs), rhs->shape(),
-                                                     rhs_collapsing_dim),
-                                    rhs),
-                           rhs_collapsing_dim))));
+        dot, reshape_if_necessary(add_reduce_in_f32(
+                 multiply(broadcast_to_dim(Flatten(lhs), rhs->shape(),
+                                           rhs_collapsing_dim),
+                          rhs),
+                 rhs_collapsing_dim))));
     return true;
   }
 
@@ -826,7 +839,7 @@
       (ShapeUtil::Rank(rhs->shape()) == 2 &&
        rhs->shape().dimensions(rhs_kept_dim) == 1)) {
     TF_RETURN_IF_ERROR(ReplaceInstruction(
-        dot, reshape_if_necessary(AddReduce(
+        dot, reshape_if_necessary(add_reduce_in_f32(
                  multiply(lhs, broadcast_to_dim(Flatten(rhs), lhs->shape(),
                                                 lhs_collapsing_dim)),
                  lhs_collapsing_dim))));
@@ -1061,7 +1074,8 @@
   const int m = left_operand->shape().dimensions(1 - lhs_contracting_dimension);
   const int n =
       right_operand->shape().dimensions(1 - rhs_contracting_dimension);
-  auto memoized_shape = ShapeUtil::MakeShape(F32, {m, n});
+  auto memoized_shape =
+      ShapeUtil::MakeShape(dot->shape().element_type(), {m, n});
   auto* memoized_inst = computation_->AddInstruction(
       HloInstruction::CreateDot(memoized_shape, left_operand, right_operand,
                                 dnums, dot->precision_config()));
@@ -1109,10 +1123,12 @@
   HloInstruction *lhs, *rhs;
   CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));
 
-  // Only optimize F32 dot operations where the dot, rhs and lhs are rank 2 or
-  // below.
-  if (dot->shape().element_type() != F32 || ShapeUtil::Rank(lhs->shape()) > 2 ||
-      ShapeUtil::Rank(rhs->shape()) > 2 || ShapeUtil::Rank(dot->shape()) > 2) {
+  // Only optimize F32 or BF16 dot operations where the dot, rhs and lhs are
+  // rank 2 or below.
+  if ((dot->shape().element_type() != F32 &&
+       dot->shape().element_type() != BF16) ||
+      ShapeUtil::Rank(lhs->shape()) > 2 || ShapeUtil::Rank(rhs->shape()) > 2 ||
+      ShapeUtil::Rank(dot->shape()) > 2) {
     return Status::OK();
   }
 
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h
index b864c37..9f8d0ee 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.h
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h
@@ -24,7 +24,7 @@
 namespace xla {
 
 // A pass which performs algebraic simplifications.
-class AlgebraicSimplifier : public HloPassInterface {
+class AlgebraicSimplifier : public HloModulePass {
  public:
   // Given shapes 'from_shape' and 'to_shape', determines if it is valid to
   // bitcast from 'from_shape' to 'to_shape' after considering platform
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 3fc1ba2..2047f89 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -3233,17 +3233,18 @@
 class DotStrengthReductionTest
     : public AlgebraicSimplifierTest,
       public ::testing::WithParamInterface<
-          ::testing::tuple<int, int, int, bool, bool>> {};
+          ::testing::tuple<int, int, int, bool, bool, PrimitiveType>> {};
 TEST_P(DotStrengthReductionTest, DotStrengthReduction) {
   int m, k, n;
   bool transpose_lhs, transpose_rhs;
-  std::tie(m, k, n, transpose_lhs, transpose_rhs) = GetParam();
+  PrimitiveType element_type;
+  std::tie(m, k, n, transpose_lhs, transpose_rhs, element_type) = GetParam();
 
-  Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n});
-  Shape lhs_shape = ShapeUtil::MakeShape(F32, {m, k});
-  Shape transposed_lhs_shape = ShapeUtil::MakeShape(F32, {k, m});
-  Shape rhs_shape = ShapeUtil::MakeShape(F32, {k, n});
-  Shape transposed_rhs_shape = ShapeUtil::MakeShape(F32, {n, k});
+  Shape dot_shape = ShapeUtil::MakeShape(element_type, {m, n});
+  Shape lhs_shape = ShapeUtil::MakeShape(element_type, {m, k});
+  Shape transposed_lhs_shape = ShapeUtil::MakeShape(element_type, {k, m});
+  Shape rhs_shape = ShapeUtil::MakeShape(element_type, {k, n});
+  Shape transposed_rhs_shape = ShapeUtil::MakeShape(element_type, {n, k});
   HloComputation::Builder builder(TestName());
 
   auto lhs = builder.AddInstruction(HloInstruction::CreateParameter(
@@ -3285,7 +3286,7 @@
     DotStrengthReductionTestInstantiation, DotStrengthReductionTest,
     ::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2),
                        ::testing::Values(1, 2), ::testing::Bool(),
-                       ::testing::Bool()));
+                       ::testing::Bool(), ::testing::Values(F32, BF16)));
 
 struct DotOfConcatTestSpec {
   int64 m;
diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.h b/tensorflow/compiler/xla/service/batch_dot_simplification.h
index 79d37f0..5b625bf 100644
--- a/tensorflow/compiler/xla/service/batch_dot_simplification.h
+++ b/tensorflow/compiler/xla/service/batch_dot_simplification.h
@@ -25,7 +25,7 @@
 // Normally these would live in the algebraic simplifier, but we want to run
 // this to fixpoint (this pass reaches fixed point in one execution) before we
 // run the DotDecomposer.
-class BatchDotSimplification : public HloPassInterface {
+class BatchDotSimplification : public HloModulePass {
  public:
   StatusOr<bool> Run(HloModule* module) override;
   absl::string_view name() const override;
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.h b/tensorflow/compiler/xla/service/batchnorm_expander.h
index 76e3217..147f3ae 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander.h
+++ b/tensorflow/compiler/xla/service/batchnorm_expander.h
@@ -26,7 +26,7 @@
 // A pass which rewrites batch norm operations into more operations. Breaking a
 // big operation into smaller operations helps leverage our generic fusion
 // logic.
-class BatchNormExpander : public HloPassInterface {
+class BatchNormExpander : public HloModulePass {
  public:
   // When use_fusion is set, a multi-output fusion node is created.
   BatchNormExpander(bool rewrite_training_op = false,
diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h
index 5dcd31b..cb3d12f 100644
--- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h
+++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h
@@ -31,7 +31,7 @@
 // optimization pipeline followed by a DCE pass. If other passes are needed
 // after this pass, run BFloat16MixedPrecisionRemoval first to undo some of the
 // changed made by this pass.
-class BFloat16ConversionFolding : public HloPassInterface {
+class BFloat16ConversionFolding : public HloModulePass {
  public:
   explicit BFloat16ConversionFolding(const BFloat16Support* bfloat16_support)
       : bfloat16_support_(bfloat16_support) {}
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.h b/tensorflow/compiler/xla/service/bfloat16_normalization.h
index 30b6346..f48e925 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization.h
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization.h
@@ -25,7 +25,7 @@
 // A pass which adds F32 <-> BF16 conversions for HLO instructions that do not
 // support BF16 input/output or mixed precision, according to the passed-in
 // backend-specific BF16 support rules.
-class BFloat16Normalization : public HloPassInterface {
+class BFloat16Normalization : public HloModulePass {
  public:
   explicit BFloat16Normalization(const BFloat16Support* bfloat16_support)
       : bfloat16_support_(bfloat16_support) {}
@@ -48,7 +48,7 @@
 // use mixed precision; it removes mixed precision even if the backend supports
 // it. This pass is used to make the HLO module valid for other HLO passes which
 // do not support mixed precision.
-class BFloat16MixedPrecisionRemoval : public HloPassInterface {
+class BFloat16MixedPrecisionRemoval : public HloModulePass {
  public:
   BFloat16MixedPrecisionRemoval() {}
 
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h
index 1ee6497..6a62439 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation.h
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h
@@ -58,7 +58,7 @@
 // BFloat16ConversionFolding. If other passes are needed after this pass, run
 // BFloat16MixedPrecisionRemoval first to undo some of the changes made by this
 // pass.
-class BFloat16Propagation : public HloPassInterface {
+class BFloat16Propagation : public HloModulePass {
  public:
   explicit BFloat16Propagation(const BFloat16Support* bfloat16_support);
 
diff --git a/tensorflow/compiler/xla/service/call_inliner.h b/tensorflow/compiler/xla/service/call_inliner.h
index c5cd88b..08c4aff 100644
--- a/tensorflow/compiler/xla/service/call_inliner.h
+++ b/tensorflow/compiler/xla/service/call_inliner.h
@@ -25,7 +25,7 @@
 
 // For every kCall operation in the main computation, we inline the body of the
 // called function, and proceed recursively.
-class CallInliner : public HloPassInterface {
+class CallInliner : public HloModulePass {
  public:
   using InlinedInstructionMap =
       std::unordered_map<HloInstruction*, HloInstruction*>;
diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc
index e5a6c28..96bd261 100644
--- a/tensorflow/compiler/xla/service/compile_only_service.cc
+++ b/tensorflow/compiler/xla/service/compile_only_service.cc
@@ -97,7 +97,7 @@
     TF_ASSIGN_OR_RETURN(
         std::unique_ptr<HloModule> hlo_module,
         HloModule::CreateFromProto(instance.computation, *module_config));
-    TF_RETURN_IF_ERROR(MaybeDumpHloModule(*hlo_module));
+    TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*hlo_module));
     hlo_modules.push_back(std::move(hlo_module));
   }
 
diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.h b/tensorflow/compiler/xla/service/conditional_simplifier.h
index 3de50cb..2223ad6 100644
--- a/tensorflow/compiler/xla/service/conditional_simplifier.h
+++ b/tensorflow/compiler/xla/service/conditional_simplifier.h
@@ -25,7 +25,7 @@
 
 // HLO pass that removes kConditional with a constant predicate, replacing them
 // with their true or false computation as appropriate.
-class ConditionalSimplifier : public HloPassInterface {
+class ConditionalSimplifier : public HloModulePass {
  public:
   absl::string_view name() const override { return "simplify-conditional"; }
   StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h
index 4988947..ce0138e 100644
--- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h
+++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h
@@ -25,7 +25,7 @@
 
 // A pass which rewrites convolutions with feature_group_count > 1 into
 // convolutions with feature_group_count = 1.
-class ConvolutionFeatureGroupConverter : public HloPassInterface {
+class ConvolutionFeatureGroupConverter : public HloModulePass {
  public:
   ConvolutionFeatureGroupConverter() {}
 
diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h
index d308f6b..c097089 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.h
+++ b/tensorflow/compiler/xla/service/copy_insertion.h
@@ -43,7 +43,7 @@
 //   (3) The buffer set of the root instruction of the entry computation must be
 //       unambiguous and distinct. That is, InstructionAliasSet::IsAmbiguous and
 //       InstructionAliasSet::IsDistinct return true.
-class CopyInsertion : public HloPassInterface {
+class CopyInsertion : public HloModulePass {
  public:
   absl::string_view name() const override { return "copy-insertion"; }
 
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 8cc522a..bf62798 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -180,6 +180,7 @@
         ":runtime_conv2d_mkl",
         ":runtime_fft",
         ":runtime_fork_join",
+        ":runtime_key_value_sort",
         ":runtime_matmul",
         ":runtime_matmul_mkl",
         ":runtime_single_threaded_conv2d",
@@ -624,6 +625,18 @@
 )
 
 cc_library(
+    name = "runtime_key_value_sort",
+    srcs = ["runtime_key_value_sort.cc"],
+    hdrs = ["runtime_key_value_sort.h"],
+    copts = runtime_copts(),
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/core:framework_lite",
+        "//third_party/eigen3",
+    ],
+)
+
+cc_library(
     name = "runtime_fork_join",
     srcs = ["runtime_fork_join.cc"],
     hdrs = ["runtime_fork_join.h"],
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h
index 59437e8..becee3f 100644
--- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h
@@ -31,7 +31,7 @@
 // called canonical convolutions). This pass expands non-canonical convolutions
 // into reshapes and canonical convolutions, so that these non-canonical
 // convolutions can run faster.
-class ConvCanonicalization : public HloPassInterface {
+class ConvCanonicalization : public HloModulePass {
  public:
   explicit ConvCanonicalization(
       const TargetMachineFeatures* target_machine_features)
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h
index d49f7d7..076235f 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h
@@ -30,7 +30,7 @@
 //
 // TODO(b/62548313): Remove this when buffer assignment is smarter
 // (module-scoped).
-class CpuCopyInsertion : public HloPassInterface {
+class CpuCopyInsertion : public HloModulePass {
  public:
   absl::string_view name() const override { return "copy-insertion"; }
 
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h
index 6af724b..a39a9d4 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h
@@ -23,7 +23,7 @@
 // This pass should run early in the HLO pipeline and checks for HLO constructs
 // which are not supported by the CPU backend and cannot be removed via HLO
 // transformations (eg, sparse layouts).
-class CpuHloSupportChecker : public HloPassInterface {
+class CpuHloSupportChecker : public HloModulePass {
  public:
   CpuHloSupportChecker() = default;
   ~CpuHloSupportChecker() override = default;
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
index 8a44c38..7e15909 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
@@ -74,6 +74,30 @@
     "__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation";
 extern const char* const kParallelForkJoinSymbolName =
     "__xla_cpu_runtime_ParallelForkJoin";
+extern const char* const kKeyValueSortPREDSymbolName =
+    "__xla_cpu_runtime_KeyValueSortPRED";
+extern const char* const kKeyValueSortS8SymbolName =
+    "__xla_cpu_runtime_KeyValueSortS8";
+extern const char* const kKeyValueSortU8SymbolName =
+    "__xla_cpu_runtime_KeyValueSortU8";
+extern const char* const kKeyValueSortS16SymbolName =
+    "__xla_cpu_runtime_KeyValueSortS16";
+extern const char* const kKeyValueSortU16SymbolName =
+    "__xla_cpu_runtime_KeyValueSortU16";
+extern const char* const kKeyValueSortF16SymbolName =
+    "__xla_cpu_runtime_KeyValueSortF16";
+extern const char* const kKeyValueSortS32SymbolName =
+    "__xla_cpu_runtime_KeyValueSortS32";
+extern const char* const kKeyValueSortU32SymbolName =
+    "__xla_cpu_runtime_KeyValueSortU32";
+extern const char* const kKeyValueSortF32SymbolName =
+    "__xla_cpu_runtime_KeyValueSortF32";
+extern const char* const kKeyValueSortS64SymbolName =
+    "__xla_cpu_runtime_KeyValueSortS64";
+extern const char* const kKeyValueSortU64SymbolName =
+    "__xla_cpu_runtime_KeyValueSortU64";
+extern const char* const kKeyValueSortF64SymbolName =
+    "__xla_cpu_runtime_KeyValueSortF64";
 
 extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_";
 }  // namespace runtime
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
index aa0e967..e6345e0 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
@@ -63,6 +63,18 @@
 extern const char* const kAcquireOutfeedBufferForPopulationSymbolName;
 extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName;
 extern const char* const kParallelForkJoinSymbolName;
+extern const char* const kKeyValueSortPREDSymbolName;
+extern const char* const kKeyValueSortS8SymbolName;
+extern const char* const kKeyValueSortU8SymbolName;
+extern const char* const kKeyValueSortS16SymbolName;
+extern const char* const kKeyValueSortU16SymbolName;
+extern const char* const kKeyValueSortF16SymbolName;
+extern const char* const kKeyValueSortS32SymbolName;
+extern const char* const kKeyValueSortU32SymbolName;
+extern const char* const kKeyValueSortF32SymbolName;
+extern const char* const kKeyValueSortS64SymbolName;
+extern const char* const kKeyValueSortU64SymbolName;
+extern const char* const kKeyValueSortF64SymbolName;
 
 // All symbol names for XLA CPU runtime functions need to start with this
 // prefix.
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index df8c2a6..c32f253 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -495,8 +495,150 @@
 }
 
 Status IrEmitter::HandleSort(HloInstruction* sort) {
-  // TODO(b/26783907): Implement sort on CPU.
-  return Unimplemented("Sort is not implemented on CPU.");
+  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort));
+  auto keys = sort->operand(0);
+  auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr;
+  ShapeIndex keys_shape_index({});
+  ShapeIndex values_shape_index({});
+  if (values != nullptr) {
+    keys_shape_index = ShapeIndex({0});
+    values_shape_index = ShapeIndex({1});
+  }
+  auto keys_destination = GetAllocationSlice(*sort, keys_shape_index);
+  auto keys_destination_address =
+      EmitBufferPointer(keys_destination, keys->shape());
+  auto values_destination = GetAllocationSlice(*sort, values_shape_index);
+  llvm::Value* values_destination_address = nullptr;
+
+  // The sort is implemented in place, so we first copy the operand buffer to
+  // the output buffer if they are not the same.
+  if (keys_destination != GetAllocationSlice(*keys)) {
+    int64 primitive_type_size =
+        ShapeUtil::ByteSizeOfPrimitiveType(keys->shape().element_type());
+    auto source_buffer = GetEmittedValueFor(keys);
+    int64 keys_size = ByteSizeOf(keys->shape());
+    MemCpy(keys_destination_address, /*DstAlign=*/primitive_type_size,
+           source_buffer,
+           /*SrcAlign=*/primitive_type_size, keys_size);
+  }
+  if (values != nullptr) {
+    values_destination_address =
+        EmitBufferPointer(values_destination, values->shape());
+    if (values_destination != GetAllocationSlice(*values)) {
+      int64 primitive_type_size =
+          ShapeUtil::ByteSizeOfPrimitiveType(values->shape().element_type());
+      auto source_buffer = GetEmittedValueFor(values);
+      int64 values_size = ByteSizeOf(values->shape());
+      MemCpy(values_destination_address, /*DstAlign=*/primitive_type_size,
+             source_buffer,
+             /*SrcAlign=*/primitive_type_size, values_size);
+    }
+  }
+
+  // Normalize the shape and the dimension to sort.
+  Shape normalized_keys_shape =
+      ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
+          keys->shape());
+  int64 physical_dimension_to_sort = LayoutUtil::MakeLogicalToPhysical(
+      keys->shape().layout())[sort->dimensions(0)];
+
+  int64 sort_dimension_elements =
+      normalized_keys_shape.dimensions(physical_dimension_to_sort);
+  int64 higher_dimensions = 1;
+  for (int64 i = 0; i < physical_dimension_to_sort; ++i) {
+    higher_dimensions *= normalized_keys_shape.dimensions(i);
+  }
+  int64 lower_dimensions = 1;
+  for (int64 i = ShapeUtil::Rank(normalized_keys_shape) - 1;
+       i > physical_dimension_to_sort; --i) {
+    lower_dimensions *= normalized_keys_shape.dimensions(i);
+  }
+
+  PrimitiveType keys_type = keys->shape().element_type();
+  const char* fn_name = nullptr;
+  llvm::Type* keys_native_type = nullptr;
+  switch (keys_type) {
+    case PRED:
+      fn_name = runtime::kKeyValueSortPREDSymbolName;
+      keys_native_type = b_.getInt8PtrTy();
+      break;
+    case S8:
+      fn_name = runtime::kKeyValueSortS8SymbolName;
+      keys_native_type = b_.getInt8PtrTy();
+      break;
+    case U8:
+      fn_name = runtime::kKeyValueSortU8SymbolName;
+      keys_native_type = b_.getInt8PtrTy();
+      break;
+    case S16:
+      fn_name = runtime::kKeyValueSortS16SymbolName;
+      keys_native_type = b_.getInt16Ty()->getPointerTo();
+      break;
+    case U16:
+      fn_name = runtime::kKeyValueSortU16SymbolName;
+      keys_native_type = b_.getInt16Ty()->getPointerTo();
+      break;
+    case F16:
+      fn_name = runtime::kKeyValueSortF16SymbolName;
+      keys_native_type = b_.getHalfTy()->getPointerTo();
+      break;
+    case S32:
+      fn_name = runtime::kKeyValueSortS32SymbolName;
+      keys_native_type = b_.getInt32Ty()->getPointerTo();
+      break;
+    case U32:
+      fn_name = runtime::kKeyValueSortU32SymbolName;
+      keys_native_type = b_.getInt32Ty()->getPointerTo();
+      break;
+    case F32:
+      fn_name = runtime::kKeyValueSortF32SymbolName;
+      keys_native_type = b_.getFloatTy()->getPointerTo();
+      break;
+    case S64:
+      fn_name = runtime::kKeyValueSortS64SymbolName;
+      keys_native_type = b_.getInt64Ty()->getPointerTo();
+      break;
+    case U64:
+      fn_name = runtime::kKeyValueSortU64SymbolName;
+      keys_native_type = b_.getInt64Ty()->getPointerTo();
+      break;
+    case F64:
+      fn_name = runtime::kKeyValueSortF64SymbolName;
+      keys_native_type = b_.getDoubleTy()->getPointerTo();
+      break;
+    default:
+      return Unimplemented(
+          "Element type %s not supported in the Sort op on CPU.",
+          PrimitiveType_Name(keys_type));
+  }
+
+  llvm::FunctionType* key_value_sort_type = llvm::FunctionType::get(
+      b_.getVoidTy(),
+      {keys_native_type, b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(),
+       b_.getInt8PtrTy(), b_.getInt32Ty()},
+      /*isVarArg=*/false);
+  auto* key_value_sort_func = llvm::cast<llvm::Function>(
+      module_->getOrInsertFunction(fn_name, key_value_sort_type));
+  key_value_sort_func->setCallingConv(llvm::CallingConv::C);
+  key_value_sort_func->setDoesNotThrow();
+  key_value_sort_func->setOnlyAccessesArgMemory();
+  Call(key_value_sort_func,
+       {PointerCast(keys_destination_address, keys_native_type),
+        b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements),
+        b_.getInt64(lower_dimensions),
+        values != nullptr
+            ? PointerCast(values_destination_address, b_.getInt8PtrTy())
+            : llvm::Constant::getNullValue(b_.getInt8PtrTy()),
+        b_.getInt32(values != nullptr ? ShapeUtil::ByteSizeOfPrimitiveType(
+                                            values->shape().element_type())
+                                      : 0)});
+
+  if (values != nullptr) {
+    llvm_ir::EmitTuple(GetIrArrayFor(sort),
+                       {keys_destination_address, values_destination_address},
+                       &b_, module_);
+  }
+  return Status::OK();
 }
 
 Status IrEmitter::HandleTuple(HloInstruction* tuple) {
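
For reference, the higher/sort/lower dimension split computed in HandleSort maps the keys shape onto the [a, b, c] layout the runtime expects. A hedged worked example with a hypothetical shape (not taken from this change), assuming the default descending layout so logical and physical dimensions coincide:

    // Hypothetical keys shape f32[2,5,3], sorting dimension 1.
    //   higher_dimensions       = 2  // product of dimensions major to the sort dim
    //   sort_dimension_elements = 5  // extent of the sort dimension
    //   lower_dimensions        = 3  // product of dimensions minor to the sort dim
    // The runtime call then receives (a, b, c) = (2, 5, 3): it sorts 2 * 3 = 6
    // rows of 5 keys each, stepping by c = 3 elements between keys of one row.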
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 3df9946..daafef4 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -163,6 +163,12 @@
   Status Preprocess(HloInstruction* hlo) override;
   Status Postprocess(HloInstruction* hlo) override;
 
+  // A convenient helper for calling BufferAssignment::GetUniqueSlice.
+  BufferAllocation::Slice GetAllocationSlice(
+      const HloInstruction& hlo, const ShapeIndex& index = {}) const {
+    return assignment_.GetUniqueSlice(&hlo, index).ConsumeValueOrDie();
+  }
+
  private:
   // Private helper to initialize an IR function for the computation.
   void InitializeIrFunction(const string& function_name);
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
index b4c0c09..ede7f43 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
@@ -142,6 +142,7 @@
       opcode == HloOpcode::kGetTupleElement || opcode == HloOpcode::kBitcast ||
       opcode == HloOpcode::kFft || opcode == HloOpcode::kInfeed ||
       opcode == HloOpcode::kOutfeed || opcode == HloOpcode::kRng ||
+      opcode == HloOpcode::kSort ||
       (opcode == HloOpcode::kConvolution &&
        PotentiallyImplementedAsEigenConvolution(*instruction,
                                                 target_machine_features_)) ||
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
index a99cd99..3822d53 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
@@ -60,7 +60,7 @@
 // own embedded computation, which is compiled as a parallel compute function,
 // and which is invoked from a kCall instruction that is lowered in codegen to
 // a runtime parallel fork/join call.
-class ParallelTaskAssigner : public HloPassInterface {
+class ParallelTaskAssigner : public HloModulePass {
  public:
   // 'max_parallelism': the maximum parallel task count per instruction.
   // 'shape_size': shape size function used by HloCostAnalysis during parallel
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc
new file mode 100644
index 0000000..e0e7deb
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc
@@ -0,0 +1,236 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h"
+
+#include <algorithm>
+#include <cmath>
+#include <cstring>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/platform/dynamic_annotations.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace {
+using tensorflow::int16;
+using tensorflow::int32;
+using tensorflow::int64;
+using tensorflow::int8;
+using tensorflow::uint16;
+using tensorflow::uint32;
+using tensorflow::uint64;
+using tensorflow::uint8;
+
+template <typename KeyType>
+void KeyValueSort(std::pair<KeyType, int64>* row_to_sort, int64 num_elements) {
+  std::sort(row_to_sort, row_to_sort + num_elements);
+}
+
+// For floating point numbers, we want a total order comparator. -NaN and NaN
+// should appear at the beginning and end of the ordering, and -0.0 should
+// appear before 0.0. Also we want to have a stable sort, so if the keys are the
+// same, we compare the index values.
+template <typename KeyType>
+bool LessThan(KeyType lhs, int64 lhs_index, KeyType rhs, int64 rhs_index) {
+  bool lhs_is_negative = std::signbit(lhs);
+  bool rhs_is_negative = std::signbit(rhs);
+  // If the signs are different, we can just compare the signs.
+  if (lhs_is_negative != rhs_is_negative) {
+    return lhs_is_negative && !rhs_is_negative;
+  }
+  bool lhs_nan = std::isnan(lhs);
+  bool rhs_nan = std::isnan(rhs);
+  // Exactly one of the two numbers is NaN: a negative NaN orders before any
+  // non-NaN value, a positive NaN after.
+  if (lhs_nan != rhs_nan) {
+    if (lhs_nan) {
+      return lhs_is_negative;
+    }
+    return !rhs_is_negative;
+  }
+  if (lhs != rhs) {
+    return lhs < rhs;
+  }
+  return lhs_index < rhs_index;
+}
+
+template <>
+void KeyValueSort(std::pair<double, int64>* row_to_sort, int64 num_elements) {
+  std::sort(row_to_sort, row_to_sort + num_elements,
+            [](const std::pair<double, int64>& lhs,
+               const std::pair<double, int64>& rhs) -> bool {
+              return LessThan(lhs.first, lhs.second, rhs.first, rhs.second);
+            });
+}
+
+template <>
+void KeyValueSort(std::pair<float, int64>* row_to_sort, int64 num_elements) {
+  std::sort(row_to_sort, row_to_sort + num_elements,
+            [](const std::pair<float, int64>& lhs,
+               const std::pair<float, int64>& rhs) -> bool {
+              return LessThan(lhs.first, lhs.second, rhs.first, rhs.second);
+            });
+}
+
+template <>
+void KeyValueSort(std::pair<Eigen::half, int64>* row_to_sort,
+                  int64 num_elements) {
+  std::sort(row_to_sort, row_to_sort + num_elements,
+            [](const std::pair<Eigen::half, int64>& lhs,
+               const std::pair<Eigen::half, int64>& rhs) -> bool {
+              return LessThan(
+                  Eigen::half_impl::half_to_float(lhs.first), lhs.second,
+                  Eigen::half_impl::half_to_float(rhs.first), rhs.second);
+            });
+}
+
+template <typename KeyType>
+void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char* values,
+                      int32 values_primitive_type_size_in_bytes) {
+  // High-level idea of the iteration/sorting logic:
+  // Conceptually we have a 3-dimensional shape [a, b, c]. b corresponds to the
+  // dimension to sort, c is the product of the more minor dimensions (set to 1
+  // if b is the most minor dimension), and a is the product of the more major
+  // dimensions (set to 1 if b is the most major dimension). There are a * c
+  // many rows that we need to sort. We iterate through these, calculate a
+  // 'base_offset' value which points to the first element of that row, and add
+  // i * c to access the 'i'-th element in that row.
+
+  int64 sort_dimension_elements = b;
+  int64 num_iteration_elements = a * c;
+  int64 sort_dimension_offset = c;
+
+  std::unique_ptr<std::pair<KeyType, int64>[]> row_to_sort(
+      new std::pair<KeyType, int64>[sort_dimension_elements]);
+  std::unique_ptr<std::string[]> reordered_values(
+      new std::string[sort_dimension_elements]);
+  for (int64 index = 0; index < num_iteration_elements; ++index) {
+    // 'index' can be split into two values which index into the 'c' dimension
+    // and the 'a' dimension, respectively. 'index' % 'c' is the index into the
+    // 'c' dimension, 'index' / 'c' is the index into the 'a' dimension. When
+    // calculating the base offset, we need to multiply the index into the 'a'
+    // dimension with 'b' * 'c'.
+    // 'index' / 'c' * 'c' * 'b' = ('index' - 'index' % 'c') * 'b'.
+    int64 base_offset =
+        index % sort_dimension_offset +
+        (index - index % sort_dimension_offset) * sort_dimension_elements;
+    // TODO(b/26783907): We could define a custom iterator class that references
+    // both arrays. Then we could avoid the intermediate copy. However, this
+    // would become more complicated, and it is not clear whether the benefit
+    // would be large enough.
+    for (int64 i = 0; i < sort_dimension_elements; ++i) {
+      row_to_sort[i] =
+          std::make_pair(keys[base_offset + i * sort_dimension_offset], i);
+    }
+    KeyValueSort(row_to_sort.get(), sort_dimension_elements);
+    for (int64 i = 0; i < sort_dimension_elements; ++i) {
+      keys[base_offset + i * sort_dimension_offset] = row_to_sort[i].first;
+    }
+    if (values == nullptr) {
+      continue;
+    }
+
+    // Reorder the values according to the order defined by the keys.
+    for (int64 i = 0; i < sort_dimension_elements; ++i) {
+      int64 memory_index =
+          (base_offset + row_to_sort[i].second * sort_dimension_offset) *
+          values_primitive_type_size_in_bytes;
+
+      reordered_values[i] = std::string(values + memory_index,
+                                        values_primitive_type_size_in_bytes);
+    }
+    for (int64 i = 0; i < sort_dimension_elements; ++i) {
+      int64 memory_index = (base_offset + i * sort_dimension_offset) *
+                           values_primitive_type_size_in_bytes;
+      memcpy(values + memory_index, reordered_values[i].c_str(),
+             values_primitive_type_size_in_bytes);
+    }
+  }
+}
+}  // namespace
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortPRED(
+    bool* keys, int64 a, int64 b, int64 c, char* values,
+    int32 values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS8(
+    int8* keys, int64 a, int64 b, int64 c, char* values,
+    int32 values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU8(
+    uint8* keys, int64 a, int64 b, int64 c, char* values,
+    int32 values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS16(
+    int16* keys, int64 a, int64 b, int64 c, char* values,
+    int32 values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU16(
+    uint16* keys, int64 a, int64 b, int64 c, char* values,
+    int32 values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF16(
+    Eigen::half* keys, int64 a, int64 b, int64 c, char* values,
+    int32 values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS32(
+    int32* keys, int64 a, int64 b, int64 c, char* values,
+    int32 values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU32(
+    uint32* keys, int64 a, int64 b, int64 c, char* values,
+    int32 values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF32(
+    float* keys, int64 a, int64 b, int64 c, char* values,
+    int32 values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS64(
+    int64* keys, int64 a, int64 b, int64 c, char* values,
+    int32 values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU64(
+    uint64* keys, int64 a, int64 b, int64 c, char* values,
+    int32 values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF64(
+    double* keys, int64 a, int64 b, int64 c, char* values,
+    int32 values_primitive_type_size_in_bytes) {
+  KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
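
The base_offset arithmetic in KeyValueSortImpl is easiest to see on a concrete shape. A standalone sketch with hypothetical sizes, using the same [a, b, c] convention as above (not part of the runtime itself):

    #include <cstdint>
    #include <cstdio>

    int main() {
      const int64_t a = 2, b = 4, c = 3;  // hypothetical [a, b, c] shape
      for (int64_t index = 0; index < a * c; ++index) {
        // index % c is the position in the 'c' dimension, index / c in 'a'.
        const int64_t base_offset = index % c + (index - index % c) * b;
        // The i-th key of this row lives at base_offset + i * c.
        std::printf("row %lld starts at flat offset %lld\n",
                    static_cast<long long>(index),
                    static_cast<long long>(base_offset));
      }
      return 0;
    }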
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h
new file mode 100644
index 0000000..28e35e8
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h
@@ -0,0 +1,88 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/platform/types.h"
+
+extern "C" {
+
+// 'keys' represents a 3-dimensional shape with dimensions [a, b, c]. The 'b'
+// dimension of 'keys' is sorted into ascending order. 'values' can be nullptr.
+// If 'values' is not nullptr, the elements in 'values' are reordered in such a
+// way that if the element at index 'i' in 'keys' was moved to index 'j', the
+// element at index 'i' in 'values' is also moved to index 'j' (i.e., each key
+// stays paired with its original value).
+extern void __xla_cpu_runtime_KeyValueSortPRED(
+    bool* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
+    char* values, tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortS8(
+    tensorflow::int8* keys, tensorflow::int64 a, tensorflow::int64 b,
+    tensorflow::int64 c, char* values,
+    tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortU8(
+    tensorflow::uint8* keys, tensorflow::int64 a, tensorflow::int64 b,
+    tensorflow::int64 c, char* values,
+    tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortS16(
+    tensorflow::int16* keys, tensorflow::int64 a, tensorflow::int64 b,
+    tensorflow::int64 c, char* values,
+    tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortU16(
+    tensorflow::uint16* keys, tensorflow::int64 a, tensorflow::int64 b,
+    tensorflow::int64 c, char* values,
+    tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortF16(
+    Eigen::half* keys, tensorflow::int64 a, tensorflow::int64 b,
+    tensorflow::int64 c, char* values,
+    tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortS32(
+    tensorflow::int32* keys, tensorflow::int64 a, tensorflow::int64 b,
+    tensorflow::int64 c, char* values,
+    tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortU32(
+    tensorflow::uint32* keys, tensorflow::int64 a, tensorflow::int64 b,
+    tensorflow::int64 c, char* values,
+    tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortF32(
+    float* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
+    char* values, tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortS64(
+    tensorflow::int64* keys, tensorflow::int64 a, tensorflow::int64 b,
+    tensorflow::int64 c, char* values,
+    tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortU64(
+    tensorflow::uint64* keys, tensorflow::int64 a, tensorflow::int64 b,
+    tensorflow::int64 c, char* values,
+    tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortF64(
+    double* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
+    char* values, tensorflow::int32 values_primitive_type_size_in_bytes);
+}
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_
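
A hedged usage sketch of one entry point: a flat buffer of six floats viewed as shape [a, b, c] = [1, 6, 1], keys-only (values == nullptr). SortSixFloats is a hypothetical helper, not part of this change:

    #include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h"

    void SortSixFloats() {
      float keys[6] = {3.f, -0.f, 0.f, -1.f, 2.f, 1.f};
      __xla_cpu_runtime_KeyValueSortF32(keys, /*a=*/1, /*b=*/6, /*c=*/1,
                                        /*values=*/nullptr,
                                        /*values_primitive_type_size_in_bytes=*/0);
      // keys is now {-1.f, -0.f, 0.f, 1.f, 2.f, 3.f}; -0.0 orders before 0.0
      // under the total-order comparator in runtime_key_value_sort.cc.
    }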
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
index bf98064..9ec0c8f 100644
--- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
@@ -35,6 +35,7 @@
 #include "tensorflow/compiler/xla/service/cpu/runtime_fft.h"
 #include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h"
 #include "tensorflow/compiler/xla/service/cpu/runtime_fp16.h"
+#include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h"
 #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
 #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h"
 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h"
@@ -202,6 +203,18 @@
   REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin);
   REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue);
   REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation);
+  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortPRED);
+  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS8);
+  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU8);
+  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS16);
+  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU16);
+  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF16);
+  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS32);
+  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU32);
+  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF32);
+  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS64);
+  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU64);
+  REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF64);
 
   registry->Register("__gnu_f2h_ieee", reinterpret_cast<void*>(__gnu_f2h_ieee));
   registry->Register("__gnu_h2f_ieee", reinterpret_cast<void*>(__gnu_h2f_ieee));
diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD
index c55206e..4b129c9 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD
@@ -180,3 +180,17 @@
         "//tensorflow/core:test_main",
     ],
 )
+
+tf_cc_test(
+    name = "cpu_key_value_sort_test",
+    srcs = ["cpu_key_value_sort_test.cc"],
+    deps = [
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_parser",
+        "//tensorflow/compiler/xla/service/cpu:cpu_compiler",
+        "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+    ],
+)
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc
new file mode 100644
index 0000000..3934c03
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc
@@ -0,0 +1,54 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
+#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+
+namespace xla {
+namespace cpu {
+namespace {
+class CpuKeyValueSortTest : public CpuCodegenTest {};
+
+TEST_F(CpuKeyValueSortTest, SortR1) {
+  const string hlo_text = R"(
+HloModule KeyValueSort
+
+ENTRY main {
+  a = f32[10] parameter(0)
+
+  ROOT result = f32[10] sort(f32[10] a), dimensions={0}
+}
+)";
+
+  string filecheck_pattern = R"(
+CHECK: call void @__xla_cpu_runtime_KeyValueSort
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(hlo_text));
+
+  CpuAotCompilationOptions options{
+      /*triple=*/"x86_64", /*cpu_name=*/"", /*features=*/"",
+      /*entry_point_name=*/"entry",
+      /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static};
+
+  CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern,
+                                /*match_optimized_ir=*/true);
+}
+
+}  // namespace
+}  // namespace cpu
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/defuser.h b/tensorflow/compiler/xla/service/defuser.h
index c326beb..aaa41fc 100644
--- a/tensorflow/compiler/xla/service/defuser.h
+++ b/tensorflow/compiler/xla/service/defuser.h
@@ -25,7 +25,7 @@
 
 // A pass which replaces all fusion instructions with the equivalent un-fused
 // instructions.
-class Defuser : public HloPassInterface {
+class Defuser : public HloModulePass {
  public:
   Defuser() {}
   ~Defuser() override {}
diff --git a/tensorflow/compiler/xla/service/despecializer.cc b/tensorflow/compiler/xla/service/despecializer.cc
index ba2a674..b3549ac 100644
--- a/tensorflow/compiler/xla/service/despecializer.cc
+++ b/tensorflow/compiler/xla/service/despecializer.cc
@@ -24,7 +24,7 @@
 namespace {
 
 // Pass which strips control dependencies from all instructions in the module.
-class ControlDepRemover : public HloPassInterface {
+class ControlDepRemover : public HloModulePass {
  public:
   ControlDepRemover() = default;
   absl::string_view name() const override { return "control-dep-remover"; }
diff --git a/tensorflow/compiler/xla/service/despecializer.h b/tensorflow/compiler/xla/service/despecializer.h
index 7be70ad..46dcc3a 100644
--- a/tensorflow/compiler/xla/service/despecializer.h
+++ b/tensorflow/compiler/xla/service/despecializer.h
@@ -30,7 +30,7 @@
 //
 // Current despecialization passes are Defuser, ImplicitBroadcastRemover,
 // and BFloat16MixedPrecisionRemoval.
-class Despecializer : public HloPassInterface {
+class Despecializer : public HloModulePass {
  public:
   Despecializer();
   absl::string_view name() const override { return "despecializer"; }
diff --git a/tensorflow/compiler/xla/service/dot_decomposer.h b/tensorflow/compiler/xla/service/dot_decomposer.h
index fc38e31..40e7a3b 100644
--- a/tensorflow/compiler/xla/service/dot_decomposer.h
+++ b/tensorflow/compiler/xla/service/dot_decomposer.h
@@ -23,7 +23,7 @@
 
 // DotDecomposer is a pass which decomposes batch Dot operations into a
 // sequence of smaller (R2) Dot operations.
-class DotDecomposer : public HloPassInterface {
+class DotDecomposer : public HloModulePass {
  public:
   // Decomposes batch Dot operations when 'decompose_batch_dot' is true.
   DotDecomposer(bool decompose_batch_dot = true)
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 4bb1e07..515267e 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -847,29 +847,34 @@
 
 StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
                                                       llvm::Value* x) {
-  if (prim_type != F32) {
-    // TODO(b/34339814): Implement inverse erf for F64.
+  if (prim_type != F16 && prim_type != F32 && prim_type != F64) {
     return Unimplemented(
         "Inverse erf is only implemented for element "
-        "type F32.");
+        "types F16, F32 and F64.");
   }
-  auto getFloat = [&](const float f) {
-    return llvm::ConstantFP::get(b_->getFloatTy(), f);
+
+  // Upcast half to float.
+  if (prim_type == F16) {
+    x = b_->CreateFPExt(x, b_->getFloatTy());
+  }
+
+  auto get_float = [&](const double f) {
+    return llvm::ConstantFP::get(x->getType(), f);
   };
-  auto multiply_add = [&](absl::Span<const float> coefficients,
+  auto multiply_add = [&](absl::Span<const double> coefficients,
                           llvm::Value* w) {
-    llvm::Value* p = getFloat(coefficients.front());
+    llvm::Value* p = get_float(coefficients.front());
     coefficients.remove_prefix(1);
-    for (float coefficient : coefficients) {
+    for (double coefficient : coefficients) {
-      p = FAdd(FMul(p, w), getFloat(coefficient));
+      p = FAdd(FMul(p, w), get_float(coefficient));
     }
     return p;
   };
 
   // Approximation for inverse error function from
   //   Giles, M., "Approximating the erfinv function".
-  // The approximation has the form:
-  //   w = log((1-x)*(1+x))
+  // The approximation has the form (float version):
+  //   w = -log((1-x)*(1+x))
   //   if ( w < 5 ) {
   //     w = w - 2.5
   //     p = sum_{i=1}^n lq[i]*w^i
@@ -879,46 +884,124 @@
   //   }
   //   return p*x
   llvm::Function* logf_fn = llvm::Intrinsic::getDeclaration(
-      module_, llvm::Intrinsic::log, {b_->getFloatTy()});
+      module_, llvm::Intrinsic::log, {x->getType()});
 
-  llvm::Value* w = FNeg(
-      Call(logf_fn, {FMul(FSub(getFloat(1.0f), x), FAdd(getFloat(1.0f), x))}));
+  llvm::Value* w = FNeg(Call(
+      logf_fn, {FMul(FSub(get_float(1.0f), x), FAdd(get_float(1.0f), x))}));
 
   llvm::Value* p_addr =
-      llvm_ir::EmitAllocaAtFunctionEntry(b_->getFloatTy(), "p.addr", b_);
+      llvm_ir::EmitAllocaAtFunctionEntry(x->getType(), "p.addr", b_);
 
-  llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
-      FCmpOLT(w, getFloat(5.0f)), "w_less_than_five", b_);
-  // Handle true BB.
-  SetToFirstInsertPoint(if_data.true_block, b_);
-  {
-    llvm::Value* lw = FSub(w, getFloat(2.5f));
-    absl::Span<const float> lq{
-        2.81022636e-08f,  3.43273939e-07f, -3.5233877e-06f,
-        -4.39150654e-06f, 0.00021858087f,  -0.00125372503f,
-        -0.00417768164f,  0.246640727f,    1.50140941f};
-    llvm::Value* p = multiply_add(lq, lw);
-    Store(p, p_addr);
+  if (prim_type == F16 || prim_type == F32) {
+    llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
+        FCmpOLT(w, get_float(5.0f)), "w_less_than_five", b_);
+    // Handle true BB.
+    SetToFirstInsertPoint(if_data.true_block, b_);
+    {
+      llvm::Value* lw = FSub(w, get_float(2.5f));
+      absl::Span<const double> lq{
+          2.81022636e-08f,  3.43273939e-07f, -3.5233877e-06f,
+          -4.39150654e-06f, 0.00021858087f,  -0.00125372503f,
+          -0.00417768164f,  0.246640727f,    1.50140941f};
+      llvm::Value* p = multiply_add(lq, lw);
+      Store(p, p_addr);
+    }
+
+    // Handle false BB.
+    SetToFirstInsertPoint(if_data.false_block, b_);
+    {
+      llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
+          module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()});
+
+      llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(3.0f));
+      absl::Span<const double> gq{
+          -0.000200214257f, 0.000100950558f, 0.00134934322f,
+          -0.00367342844f,  0.00573950773f,  -0.0076224613f,
+          0.00943887047f,   1.00167406f,     2.83297682f};
+      llvm::Value* p = multiply_add(gq, gw);
+      Store(p, p_addr);
+    }
+
+    SetToFirstInsertPoint(if_data.after_block, b_);
+  } else {
+    DCHECK(prim_type == F64);
+
+    llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
+        FCmpOLT(w, get_float(6.25)), "w_less_than_6.25", b_);
+
+    SetToFirstInsertPoint(if_data.true_block, b_);
+    {
+      llvm::Value* lw = FSub(w, get_float(3.125));
+      absl::Span<const double> c{
+          -3.6444120640178196996e-21, -1.685059138182016589e-19,
+          1.2858480715256400167e-18,  1.115787767802518096e-17,
+          -1.333171662854620906e-16,  2.0972767875968561637e-17,
+          6.6376381343583238325e-15,  -4.0545662729752068639e-14,
+          -8.1519341976054721522e-14, 2.6335093153082322977e-12,
+          -1.2975133253453532498e-11, -5.4154120542946279317e-11,
+          1.051212273321532285e-09,   -4.1126339803469836976e-09,
+          -2.9070369957882005086e-08, 4.2347877827932403518e-07,
+          -1.3654692000834678645e-06, -1.3882523362786468719e-05,
+          0.0001867342080340571352,   -0.00074070253416626697512,
+          -0.0060336708714301490533,  0.24015818242558961693,
+          1.6536545626831027356};
+      llvm::Value* p = multiply_add(c, lw);
+      Store(p, p_addr);
+    }
+
+    SetToFirstInsertPoint(if_data.false_block, b_);
+    llvm_ir::LlvmIfData if_data_second = llvm_ir::EmitIfThenElse(
+        FCmpOLT(w, get_float(16.0)), "w_less_than_16", b_);
+    SetToFirstInsertPoint(if_data_second.true_block, b_);
+    {
+      llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
+          module_, llvm::Intrinsic::sqrt, {b_->getDoubleTy()});
+
+      llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(3.25));
+      absl::Span<const double> t1{
+          2.2137376921775787049e-09,  9.0756561938885390979e-08,
+          -2.7517406297064545428e-07, 1.8239629214389227755e-08,
+          1.5027403968909827627e-06,  -4.013867526981545969e-06,
+          2.9234449089955446044e-06,  1.2475304481671778723e-05,
+          -4.7318229009055733981e-05, 6.8284851459573175448e-05,
+          2.4031110387097893999e-05,  -0.0003550375203628474796,
+          0.00095328937973738049703,  -0.0016882755560235047313,
+          0.0024914420961078508066,   -0.0037512085075692412107,
+          0.005370914553590063617,    1.0052589676941592334,
+          3.0838856104922207635};
+      llvm::Value* p = multiply_add(t1, gw);
+      Store(p, p_addr);
+    }
+
+    SetToFirstInsertPoint(if_data_second.false_block, b_);
+    {
+      llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
+          module_, llvm::Intrinsic::sqrt, {b_->getDoubleTy()});
+
+      llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(5.0));
+      absl::Span<const double> t2{
+          -2.7109920616438573243e-11, -2.5556418169965252055e-10,
+          1.5076572693500548083e-09,  -3.7894654401267369937e-09,
+          7.6157012080783393804e-09,  -1.4960026627149240478e-08,
+          2.9147953450901080826e-08,  -6.7711997758452339498e-08,
+          2.2900482228026654717e-07,  -9.9298272942317002539e-07,
+          4.5260625972231537039e-06,  -1.9681778105531670567e-05,
+          7.5995277030017761139e-05,  -0.00021503011930044477347,
+          -0.00013871931833623122026, 1.0103004648645343977,
+          4.8499064014085844221};
+      llvm::Value* p = multiply_add(t2, gw);
+      Store(p, p_addr);
+    }
+
+    SetToFirstInsertPoint(if_data.after_block, b_);
   }
-
-  // Handle false BB.
-  SetToFirstInsertPoint(if_data.false_block, b_);
-  {
-    llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
-        module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()});
-
-    llvm::Value* gw = FSub(Call(sqrtf_fn, w), getFloat(3.0f));
-    absl::Span<const float> gq{
-        -0.000200214257f, 0.000100950558f, 0.00134934322f,
-        -0.00367342844f,  0.00573950773f,  -0.0076224613f,
-        0.00943887047f,   1.00167406f,     2.83297682f};
-    llvm::Value* p = multiply_add(gq, gw);
-    Store(p, p_addr);
-  }
-
-  SetToFirstInsertPoint(if_data.after_block, b_);
   llvm::Value* p = Load(p_addr);
-  return FMul(p, x);
+  x = FMul(p, x);
+  // Trunc back to half if needed.
+  if (prim_type == F16) {
+    x = b_->CreateFPTrunc(x, b_->getHalfTy());
+  }
+  return x;
 }
 
 StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfcInv(PrimitiveType prim_type,
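
A scalar C++ reference for the single-precision branch of EmitErfInv above, with the lq/gq coefficients copied from this hunk; a sketch for checking the emitted IR against, not library code:

    #include <cmath>
    #include <initializer_list>

    float ErfInvRef(float x) {
      float w = -std::log((1.0f - x) * (1.0f + x));
      float p;
      if (w < 5.0f) {
        w -= 2.5f;
        p = 2.81022636e-08f;
        for (float c : {3.43273939e-07f, -3.5233877e-06f, -4.39150654e-06f,
                        0.00021858087f, -0.00125372503f, -0.00417768164f,
                        0.246640727f, 1.50140941f}) {
          p = p * w + c;  // same multiply_add recurrence as the emitted IR
        }
      } else {
        w = std::sqrt(w) - 3.0f;
        p = -0.000200214257f;
        for (float c : {0.000100950558f, 0.00134934322f, -0.00367342844f,
                        0.00573950773f, -0.0076224613f, 0.00943887047f,
                        1.00167406f, 2.83297682f}) {
          p = p * w + c;
        }
      }
      return p * x;
    }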
diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.h b/tensorflow/compiler/xla/service/flatten_call_graph.h
index 3cccec9..986970f 100644
--- a/tensorflow/compiler/xla/service/flatten_call_graph.h
+++ b/tensorflow/compiler/xla/service/flatten_call_graph.h
@@ -26,7 +26,7 @@
 // Flattening associates each call site with a unique computation (for
 // sequential calling contexts) This simplifies buffer assignment and
 // points-to analysis (see b/36865746 for details).
-class FlattenCallGraph : public HloPassInterface {
+class FlattenCallGraph : public HloModulePass {
  public:
   absl::string_view name() const override { return "flatten-call-graph"; }
 
diff --git a/tensorflow/compiler/xla/service/gather_expander.h b/tensorflow/compiler/xla/service/gather_expander.h
index 7bd9ea5..2b39359 100644
--- a/tensorflow/compiler/xla/service/gather_expander.h
+++ b/tensorflow/compiler/xla/service/gather_expander.h
@@ -23,7 +23,7 @@
 // This pass rewrites gather operations into (roughly) while loops of dynamic
 // slices.  This lets backends that don't support gather directly to
 // nevertheless have a minimum level of support.
-class GatherExpander : public HloPassInterface {
+class GatherExpander : public HloModulePass {
  public:
   absl::string_view name() const override { return "gather_expander"; }
   StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 64b9683..cbee4db 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -68,9 +68,7 @@
 #    srcs = [
 #        "partition_assignment_test.cc",
 #    ],
-#    tags = [
-#        "requires-gpu-sm35",
-#    ],
+#    tags = tf_cuda_tests_tags(),
 #    deps = [
 #        ":partition_assignment",
 #        "//tensorflow/core:stream_executor_no_cuda",
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
index 3a23ac1..85f3682 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
@@ -29,21 +29,51 @@
 namespace xla {
 namespace gpu {
 
-using se::dnn::AlgorithmDesc;
+ConvolutionThunk::ConvolutionThunk(
+    const HloCustomCallInstruction* cudnn_call,
+    std::vector<BufferAllocation::Slice> operand_slices,
+    BufferAllocation::Slice result_slice, BufferAllocation::Slice scratch_slice,
+    BufferAllocation::Slice tuple_result_slice)
+    : Thunk(Kind::kConvolution, cudnn_call),
+      cudnn_call_(cudnn_call),
+      operand_buffers_(std::move(operand_slices)),
+      result_buffer_(result_slice),
+      scratch_buffer_(scratch_slice),
+      tuple_result_buffer_(tuple_result_slice) {}
 
 Status ConvolutionThunk::ExecuteOnStream(
     const BufferAllocations& buffer_allocations, se::Stream* stream,
     HloExecutionProfiler* profiler) {
   CudnnConvParams params;
+  TF_RETURN_IF_ERROR(PopulateCudnnConvParams(cudnn_call_, &params));
 
-  params.input_buf = buffer_allocations.GetDeviceAddress(input_buffer_);
-  params.filter_buf = buffer_allocations.GetDeviceAddress(filter_buffer_);
-  params.output_buf = buffer_allocations.GetDeviceAddress(output_buffer_);
+  switch (params.kind) {
+    case CudnnConvKind::kForward:
+      params.input_buf =
+          buffer_allocations.GetDeviceAddress(operand_buffers_[0]);
+      params.filter_buf =
+          buffer_allocations.GetDeviceAddress(operand_buffers_[1]);
+      params.output_buf = buffer_allocations.GetDeviceAddress(result_buffer_);
+      break;
+    case CudnnConvKind::kBackwardInput:
+      params.input_buf = buffer_allocations.GetDeviceAddress(result_buffer_);
+      params.filter_buf =
+          buffer_allocations.GetDeviceAddress(operand_buffers_[1]);
+      params.output_buf =
+          buffer_allocations.GetDeviceAddress(operand_buffers_[0]);
+      break;
+    case CudnnConvKind::kBackwardFilter:
+      params.input_buf =
+          buffer_allocations.GetDeviceAddress(operand_buffers_[0]);
+      params.filter_buf = buffer_allocations.GetDeviceAddress(result_buffer_);
+      params.output_buf =
+          buffer_allocations.GetDeviceAddress(operand_buffers_[1]);
+      break;
+  }
+
   se::DeviceMemoryBase scratch =
       buffer_allocations.GetDeviceAddress(scratch_buffer_);
 
-  TF_RETURN_IF_ERROR(PopulateCudnnConvParams(cudnn_call_, &params));
-
   auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
   TF_RETURN_IF_ERROR(RunCudnnConvolution(params, scratch, stream));
 
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
index d7d1f91..f53bc54 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
@@ -42,24 +42,12 @@
   // Constructs a thunk for launching a DNN convolution.  When run, it will
   // write a tuple (result, scratch_memory) into `tuple_result_buffer`.
   //
-  // Note that "output" here doesn't refer to the output from running this
-  // thunk, but rather to the "output" of a hypothetical forward convolution
-  // that corresponds to this input+filter+output triple.  That is, the result
-  // generated by this thunk is "output" for forward convs, "input" for
-  // backward-input convs, and "filter" for backward-filter convs.
+  // operand_slices should be in the same order as cudnn_call->operands().
   ConvolutionThunk(const HloCustomCallInstruction* cudnn_call,
-                   BufferAllocation::Slice input_slice,
-                   BufferAllocation::Slice filter_slice,
-                   BufferAllocation::Slice output_slice,
+                   std::vector<BufferAllocation::Slice> operand_slices,
+                   BufferAllocation::Slice result_slice,
                    BufferAllocation::Slice scratch_slice,
-                   BufferAllocation::Slice tuple_result_slice)
-      : Thunk(Kind::kConvolution, cudnn_call),
-        cudnn_call_(cudnn_call),
-        input_buffer_(std::move(input_slice)),
-        filter_buffer_(std::move(filter_slice)),
-        output_buffer_(std::move(output_slice)),
-        scratch_buffer_(std::move(scratch_slice)),
-        tuple_result_buffer_(std::move(tuple_result_slice)) {}
+                   BufferAllocation::Slice tuple_result_slice);
 
   ConvolutionThunk(const ConvolutionThunk&) = delete;
   ConvolutionThunk& operator=(const ConvolutionThunk&) = delete;
@@ -71,9 +59,8 @@
 
  private:
   const HloCustomCallInstruction* cudnn_call_;
-  BufferAllocation::Slice input_buffer_;
-  BufferAllocation::Slice filter_buffer_;
-  BufferAllocation::Slice output_buffer_;
+  std::vector<BufferAllocation::Slice> operand_buffers_;
+  BufferAllocation::Slice result_buffer_;
   BufferAllocation::Slice scratch_buffer_;
   BufferAllocation::Slice tuple_result_buffer_;
 };
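
For readers of ExecuteOnStream above, the buffer mapping implied by the new operand_slices ordering (with "input"/"filter"/"output" referring to the equivalent forward convolution, as before):

    // kForward:        input = operand[0], filter = operand[1], output = result
    // kBackwardInput:  input = result,     filter = operand[1], output = operand[0]
    // kBackwardFilter: input = operand[0], filter = result,     output = operand[1]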
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h
index 6e2e330..c3f5850 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h
@@ -52,7 +52,7 @@
 // The GPU backend does not implement a lowering for the batchnorm HLOs -- it
 // expects them to be lowered to cudnn calls via this pass or to HLO soup via
 // BatchNormRewriter.
-class CudnnBatchNormRewriter : public HloPassInterface {
+class CudnnBatchNormRewriter : public HloModulePass {
  public:
   absl::string_view name() const override { return "cudnn_batchnorm_rewriter"; }
   StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
index c607aea..f528e62 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
@@ -221,25 +221,12 @@
     allocator = &*se_allocator;
   }
 
-  // Allocate space for the input, filter, and output of the convolution.  We
-  // use a ScratchAllocator for this instead of calling allocator_ directly so
-  // that our allocations don't leak.
-  ScratchAllocator input_output_allocator(device_ordinal, allocator);
-  TF_ASSIGN_OR_RETURN(params.input_buf,
-                      input_output_allocator.AllocateBytes(
-                          &stream, ShapeUtil::ByteSizeOf(input_shape)));
-  TF_ASSIGN_OR_RETURN(params.filter_buf,
-                      input_output_allocator.AllocateBytes(
-                          &stream, ShapeUtil::ByteSizeOf(filter_shape)));
-  TF_ASSIGN_OR_RETURN(params.output_buf,
-                      input_output_allocator.AllocateBytes(
-                          &stream, ShapeUtil::ByteSizeOf(output_shape)));
-
-  if (cross_check_enabled) {
-    // Broadcast a constant to the buffer, instead of zeroing the buffer. A
-    // non-zero constant is useful for the cross checking, because zero-inputs
-    // may not always reveal the bugs.
-    const auto initialize_f16 = [&stream](DeviceMemoryBase buffer) {
+  const auto initialize_buffer = [&stream, cross_check_enabled](
+                                     DeviceMemoryBase buffer) {
+    if (cross_check_enabled) {
+      // Broadcast a constant to the buffer, instead of zeroing the buffer. A
+      // non-zero constant is useful for cross checking, because zero inputs
+      // may not always reveal bugs.
       CHECK_EQ(0, (uintptr_t)buffer.opaque() % 4);
       size_t left_over_bytes = buffer.size() % 4;
       CHECK_EQ(0, left_over_bytes % 2);
@@ -257,19 +244,32 @@
       DeviceMemoryBase left_over(
           static_cast<char*>(buffer.opaque()) + aligned_size, left_over_bytes);
       stream.ThenMemcpy(&left_over, halfs, left_over_bytes);
-    };
-    initialize_f16(params.input_buf);
-    initialize_f16(params.filter_buf);
-    initialize_f16(params.output_buf);
-  } else {
-    // Although we don't have evidence this matters, zero out the buffers before
-    // autotuning.  It's conceivable that using uninitialized memory as the
-    // inputs might affect performance if e.g. the inputs contain denormals, and
-    // this is easy enough.
-    stream.ThenMemZero(&params.input_buf, params.input_buf.size())
-        .ThenMemZero(&params.filter_buf, params.filter_buf.size())
-        .ThenMemZero(&params.output_buf, params.output_buf.size());
-  }
+    } else {
+      // Although we don't have evidence this matters, zero out the buffers
+      // before autotuning.  It's conceivable that using uninitialized memory as
+      // the inputs might affect performance if e.g. the inputs contain
+      // denormals, and this is easy enough.
+      stream.ThenMemZero(&buffer, buffer.size());
+    }
+  };
+
+  // Allocate space for the input, filter, and output of the convolution.  We
+  // use a ScratchAllocator for this instead of calling allocator_ directly so
+  // that our allocations don't leak.
+  ScratchAllocator input_output_allocator(device_ordinal, allocator);
+  TF_ASSIGN_OR_RETURN(params.input_buf,
+                      input_output_allocator.AllocateBytes(
+                          &stream, ShapeUtil::ByteSizeOf(input_shape)));
+  TF_ASSIGN_OR_RETURN(params.filter_buf,
+                      input_output_allocator.AllocateBytes(
+                          &stream, ShapeUtil::ByteSizeOf(filter_shape)));
+  TF_ASSIGN_OR_RETURN(params.output_buf,
+                      input_output_allocator.AllocateBytes(
+                          &stream, ShapeUtil::ByteSizeOf(output_shape)));
+
+  initialize_buffer(params.input_buf);
+  initialize_buffer(params.filter_buf);
+  initialize_buffer(params.output_buf);
 
   DeviceMemoryBase* result_buf = [&] {
     switch (params.kind) {
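
The algorithm-picker refactor folds the two initialization paths, a non-zero fp16 fill when cross checking is enabled and a zero fill otherwise, into a single initialize_buffer lambda that is applied uniformly to the input, filter, and output buffers. A simplified, self-contained illustration of the same pattern, using a plain std::vector of fp16 bit patterns instead of StreamExecutor device memory:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      const bool cross_check_enabled = true;

      // One lambda owns both initialization strategies, so every buffer goes
      // through the same code path.
      auto initialize_buffer = [cross_check_enabled](std::vector<uint16_t>* buffer) {
        if (cross_check_enabled) {
          // A non-zero constant is more likely to expose bugs than all-zero input.
          for (uint16_t& half_bits : *buffer) half_bits = 0x3C00;  // 1.0 as fp16 bits
        } else {
          for (uint16_t& half_bits : *buffer) half_bits = 0;
        }
      };

      std::vector<uint16_t> input(8), filter(4), output(8);
      initialize_buffer(&input);
      initialize_buffer(&filter);
      initialize_buffer(&output);
      std::cout << std::hex << input[0] << std::endl;  // prints 3c00
      return 0;
    }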
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
index f79b113..ce01895 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
@@ -30,7 +30,7 @@
 
 // Modifies CustomCalls to cudnn convolutions, choosing the best algorithm for
 // each and adding explicit scratch space to the CustomCalls.
-class CudnnConvolutionAlgorithmPicker : public HloPassInterface {
+class CudnnConvolutionAlgorithmPicker : public HloModulePass {
  public:
   // If the `allocator` parameter is not null, we will use it to allocate temp
   // memory while timing the various convolution algorithms.  If it's null,
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h
index fbe7e98..8d7c6fd 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h
@@ -24,7 +24,7 @@
 
 // Rewrites plain convolutions, backwards-filter convolutions, and
 // backwards-input convolutions into CustomCall HLOs that call into cuDNN.
-class CudnnConvolutionRewriter : public HloPassInterface {
+class CudnnConvolutionRewriter : public HloModulePass {
  public:
   absl::string_view name() const override {
     return "cudnn-convolution-rewriter";
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
index 2a86ac2..3310ee8 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
@@ -92,9 +92,9 @@
   VLOG(3) << "tensor_ops_enabled: "
           << algorithm.algorithm().tensor_ops_enabled();
   VLOG(3) << "Convolution kind: " << CudnnConvKindToString(kind);
-  VLOG(3) << "input shape: { " << ShapeUtil::HumanString(input_shape) << " }";
-  VLOG(3) << "filter shape: { " << ShapeUtil::HumanString(filter_shape) << " }";
-  VLOG(3) << "Output shape: { " << ShapeUtil::HumanString(output_shape) << " }";
+  VLOG(3) << "input shape: " << ShapeUtil::HumanStringWithLayout(input_shape);
+  VLOG(3) << "filter shape: " << ShapeUtil::HumanStringWithLayout(filter_shape);
+  VLOG(3) << "Output shape: " << ShapeUtil::HumanStringWithLayout(output_shape);
   VLOG(3) << "Window: { " << window.ShortDebugString() << " }";
   VLOG(3) << "Dim nums: { " << dnums.ShortDebugString() << " }";
 
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.h b/tensorflow/compiler/xla/service/gpu/fusion_merger.h
index 7e3f577..f19996e 100644
--- a/tensorflow/compiler/xla/service/gpu/fusion_merger.h
+++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.h
@@ -32,7 +32,7 @@
 // 2) The result of merging the fusion instruction into its users would not
 //    increase bytes transferred.
 //
-class FusionMerger : public HloPassInterface {
+class FusionMerger : public HloModulePass {
  public:
   absl::string_view name() const override { return "fusion merger"; }
 
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
index 75f414e..79c74e7 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
@@ -34,15 +34,6 @@
 
 namespace gpu {
 
-StatusOr<HloInstruction*> GpuCopyInsertion::FindOrInsertCopy(
-    HloInstruction* hlo) {
-  HloInstruction*& copy = hlo_to_copy_map_[hlo];
-  if (copy == nullptr) {
-    TF_ASSIGN_OR_RETURN(copy, hlo->parent()->DeepCopyInstruction(hlo));
-  }
-  return copy;
-}
-
 StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) {
   CopyInsertion generic_copy_insertion;
 
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h
index 8ffae18..4c7e38f 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h
@@ -25,20 +25,11 @@
 // Besides the modifications made by the generic xla::CopyInsertion, this
 // GPU-specific copy insertion also materializes operands of library calls by
 // inserting kCopy instructions.
-class GpuCopyInsertion : public HloPassInterface {
+class GpuCopyInsertion : public HloModulePass {
  public:
   absl::string_view name() const override { return "copy-insertion"; }
 
   StatusOr<bool> Run(HloModule* module) override;
-
- protected:
-  // Returns a copy of `hlo`. Looks in hlo_to_copy_map_ first to avoid making
-  // duplicate copies.
-  StatusOr<HloInstruction*> FindOrInsertCopy(HloInstruction* hlo);
-
-  // A map containing all copies inserted to materialize operands of library
-  // calls. The key is the copied instruction and the value is the copy.
-  tensorflow::gtl::FlatMap<HloInstruction*, HloInstruction*> hlo_to_copy_map_;
 };
 
 }  // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h
index bbb3340..9c64b4d 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h
@@ -23,7 +23,7 @@
 // This pass should run early in the HLO pipeline and checks for HLO constructs
 // which are not supported by the GPU backend and cannot be removed via HLO
 // transformations (e.g., sparse layouts).
-class GpuHloSupportChecker : public HloPassInterface {
+class GpuHloSupportChecker : public HloModulePass {
  public:
   GpuHloSupportChecker() = default;
   ~GpuHloSupportChecker() override = default;
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index b669881..c792dd2 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -465,35 +465,18 @@
 
   if (IsCustomCallToDnnConvolution(*custom_call)) {
     const auto& assn = ir_emitter_context_->buffer_assignment();
-    auto lhs_slice = GetAllocationSlice(*custom_call->operand(0));
-    auto rhs_slice = GetAllocationSlice(*custom_call->operand(1));
+    std::vector<BufferAllocation::Slice> operand_slices;
+    operand_slices.reserve(custom_call->operand_count());
+    for (const auto* operand : custom_call->operands()) {
+      operand_slices.push_back(GetAllocationSlice(*operand));
+    }
     auto tuple_result_slice = GetAllocationSlice(*custom_call);
     auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie();
     auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();
 
-    const auto& target = custom_call->custom_call_target();
-    BufferAllocation::Slice input_slice, filter_slice, output_slice;
-
-    if (target == kCudnnConvForwardCallTarget) {
-      input_slice = lhs_slice;
-      filter_slice = rhs_slice;
-      output_slice = conv_result_slice;
-    } else if (target == kCudnnConvBackwardInputCallTarget) {
-      input_slice = conv_result_slice;
-      filter_slice = rhs_slice;
-      output_slice = lhs_slice;
-    } else if (target == kCudnnConvBackwardFilterCallTarget) {
-      input_slice = lhs_slice;
-      filter_slice = conv_result_slice;
-      output_slice = rhs_slice;
-    } else {
-      LOG(FATAL) << "Unexpected custom call target: "
-                 << custom_call->custom_call_target();
-    }
-
     thunk_sequence_->emplace_back(absl::make_unique<ConvolutionThunk>(
-        Cast<HloCustomCallInstruction>(custom_call), input_slice, filter_slice,
-        output_slice, scratch_slice, tuple_result_slice));
+        Cast<HloCustomCallInstruction>(custom_call), std::move(operand_slices),
+        conv_result_slice, scratch_slice, tuple_result_slice));
     return Status::OK();
   }
 
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h
index 11dc56a..e592a37 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h
@@ -30,7 +30,7 @@
 // targeting before running this pass.
 //
 // TODO(jlebar): Also pad dots.
-class PadForTensorCores : public HloPassInterface {
+class PadForTensorCores : public HloModulePass {
  public:
   absl::string_view name() const override { return "pad for tensor cores"; }
 
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.h b/tensorflow/compiler/xla/service/gpu/pad_insertion.h
index a622e89..25cdf64 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.h
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.h
@@ -24,7 +24,7 @@
 // An HLO pass that canonicalizes convolution instructions for GPU codegen. It
 // inserts Pad instructions before Convolution instructions with uncanonicalized
 // padding, so that they can be lowered to cuDNN convolution.
-class PadInsertion : public HloPassInterface {
+class PadInsertion : public HloModulePass {
  public:
   absl::string_view name() const override { return "pad insertion"; }
 
diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD
index db4a33dc..5da6f23 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD
@@ -25,15 +25,17 @@
 )
 
 load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load(
+    "//tensorflow/core:platform/default/build_config_root.bzl",
+    "tf_cuda_tests_tags",
+)
 
 cc_library(
     name = "gpu_codegen_test",
     testonly = True,
     srcs = ["gpu_codegen_test.cc"],
     hdrs = ["gpu_codegen_test.h"],
-    tags = [
-        "requires-gpu-sm35",
-    ],
+    tags = tf_cuda_tests_tags(),
     deps = [
         "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
         "//tensorflow/compiler/xla/service:gpu_plugin",
@@ -48,9 +50,7 @@
 tf_cc_test(
     name = "gpu_copy_test",
     srcs = ["gpu_copy_test.cc"],
-    tags = [
-        "requires-gpu-sm35",
-    ],
+    tags = tf_cuda_tests_tags(),
     deps = [
         ":gpu_codegen_test",
         "//tensorflow/compiler/xla:literal",
@@ -67,9 +67,7 @@
 tf_cc_test(
     name = "gpu_ftz_test",
     srcs = ["gpu_ftz_test.cc"],
-    tags = [
-        "requires-gpu-sm35",
-    ],
+    tags = tf_cuda_tests_tags(),
     deps = [
         ":gpu_codegen_test",
         "//tensorflow/core:test_main",
@@ -79,9 +77,7 @@
 tf_cc_test(
     name = "gpu_index_test",
     srcs = ["gpu_index_test.cc"],
-    tags = [
-        "requires-gpu-sm35",
-    ],
+    tags = tf_cuda_tests_tags(),
     deps = [
         ":gpu_codegen_test",
         "//tensorflow/compiler/xla:literal",
@@ -102,9 +98,7 @@
 tf_cc_test(
     name = "gpu_infeed_test",
     srcs = ["infeed_test.cc"],
-    tags = [
-        "requires-gpu-sm35",
-    ],
+    tags = tf_cuda_tests_tags(),
     deps = [
         ":gpu_codegen_test",
         "//tensorflow/compiler/xla:literal",
@@ -125,9 +119,7 @@
 tf_cc_test(
     name = "gpu_kernel_tiling_test",
     srcs = ["gpu_kernel_tiling_test.cc"],
-    tags = [
-        "requires-gpu-sm35",
-    ],
+    tags = tf_cuda_tests_tags(),
     deps = [
         ":gpu_codegen_test",
         "//tensorflow/compiler/xla/service:hlo",
@@ -142,7 +134,7 @@
 tf_cc_test(
     name = "gpu_ldg_test",
     srcs = ["gpu_ldg_test.cc"],
-    tags = ["requires-gpu-sm35"],
+    tags = tf_cuda_tests_tags(),
     deps = [
         ":gpu_codegen_test",
         "//tensorflow/compiler/xla:literal",
@@ -159,9 +151,7 @@
 tf_cc_test(
     name = "gpu_noalias_test",
     srcs = ["gpu_noalias_test.cc"],
-    tags = [
-        "requires-gpu-sm35",
-    ],
+    tags = tf_cuda_tests_tags(),
     deps = [
         ":gpu_codegen_test",
         "//tensorflow/compiler/xla:literal",
@@ -178,9 +168,7 @@
 tf_cc_test(
     name = "gpu_fusion_test",
     srcs = ["gpu_fusion_test.cc"],
-    tags = [
-        "requires-gpu-sm35",
-    ],
+    tags = tf_cuda_tests_tags(),
     deps = [
         ":gpu_codegen_test",
         "//tensorflow/compiler/xla/service:hlo_module_config",
@@ -194,9 +182,7 @@
 tf_cc_test(
     name = "gpu_unrolling_test",
     srcs = ["gpu_unrolling_test.cc"],
-    tags = [
-        "requires-gpu-sm35",
-    ],
+    tags = tf_cuda_tests_tags(),
     deps = [
         ":gpu_codegen_test",
         "//tensorflow/compiler/xla/service:hlo_module_config",
@@ -211,9 +197,7 @@
     name = "gpu_alignment_test",
     testonly = True,
     srcs = ["gpu_alignment_test.cc"],
-    tags = [
-        "requires-gpu-sm35",
-    ],
+    tags = tf_cuda_tests_tags(),
     deps = [
         ":gpu_codegen_test",
         "//tensorflow/compiler/xla/service:gpu_plugin",
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 8c6903d..e9e70b2 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -279,11 +279,11 @@
   return Status::OK();
 }
 
-void HloComputation::set_root_instruction(
-    HloInstruction* new_root_instruction) {
+void HloComputation::set_root_instruction(HloInstruction* new_root_instruction,
+                                          bool accept_different_shape) {
   // The shape of the root (ignoring layout) is an invariant of the computation
   // for non-fusion cases.
-  if (!IsFusionComputation()) {
+  if (!IsFusionComputation() && !accept_different_shape) {
     CHECK(ShapeUtil::Compatible(new_root_instruction->shape(),
                                 root_instruction_->shape()))
         << new_root_instruction->shape() << " is incompatible with "
@@ -916,13 +916,14 @@
   return CloneWithReplacements(
       /*replacements=*/std::unordered_map<const HloInstruction*,
                                           std::unique_ptr<HloInstruction>>(),
-      context, suffix);
+      /*extras=*/{}, context, suffix);
 }
 
 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
     std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
         replacements,
-    HloCloneContext* context, const string& suffix) {
+    absl::Span<HloInstruction*> extras, HloCloneContext* context,
+    const string& suffix) {
   std::unique_ptr<HloCloneContext> context_ptr;
   if (context == nullptr) {
     context_ptr = absl::make_unique<HloCloneContext>(parent(), suffix);
@@ -944,6 +945,9 @@
 
   VLOG(1) << "Cloning " << name() << " --> " << suffix << "\n";
   std::vector<HloInstruction*> postorder;
+  for (HloInstruction* instr : extras) {
+    postorder.push_back(instr);
+  }
   for (HloInstruction* instr : MakeInstructionPostOrder()) {
     if (HloInstruction* replacement = replace(instr)) {
       postorder.push_back(replacement);
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index 91c5234..e7c98aa 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -134,9 +134,11 @@
   Status RemoveInstructionAndUnusedOperands(HloInstruction* instruction);
 
   // Set the root of the computation to the given instruction. The instruction
-  // must have already been added to the computation and have the same shape as
-  // the result of the computation for non fusion computations.
-  void set_root_instruction(HloInstruction* new_root_instruction);
+  // must have already been added to the computation. In addition, it must have
+  // the same shape as the result of the computation for non-fusion
+  // computations, unless accept_different_shape is set to true.
+  void set_root_instruction(HloInstruction* new_root_instruction,
+                            bool accept_different_shape = false);
 
   // Return the root instruction of the computation. The root instruction is the
   // instruction which produces the output of the computation.
@@ -331,10 +333,13 @@
   //
   // If replacements maps a key to nullptr, we remove that instruction from the
   // new computation.
+  // If additional instructions are used by instructions in the replacement map,
+  // they must be passed in post-order in the extras span.
   std::unique_ptr<HloComputation> CloneWithReplacements(
       std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
           replacements,
-      HloCloneContext* context = nullptr, const string& suffix = "clone");
+      absl::Span<HloInstruction*> extras, HloCloneContext* context = nullptr,
+      const string& suffix = "clone");
 
   // Returns true if the given instruction can be removed from the computation.
   // Parameter instructions cannot be removed without violating invariants of
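
Taken together, the relaxed root check and the new extras parameter let a caller swap in a root of a different shape and clone a computation whose replacement instructions reference helpers that are not yet part of it. A hedged sketch, in which computation, new_root, old_root, replacement_instr, and helper_instr are illustrative names rather than anything defined in this patch:

    // The new root must already belong to the computation; only the shape
    // compatibility check is relaxed when accept_different_shape is true.
    computation->set_root_instruction(new_root, /*accept_different_shape=*/true);

    // Clone while substituting replacement_instr for old_root. helper_instr is
    // used by replacement_instr but lives outside the original computation, so
    // it is handed to the clone through the extras span (in post order).
    std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
        replacements;
    replacements[old_root] = std::move(replacement_instr);
    std::vector<HloInstruction*> extras = {helper_instr};
    std::unique_ptr<HloComputation> clone = computation->CloneWithReplacements(
        std::move(replacements), absl::MakeSpan(extras));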
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.h b/tensorflow/compiler/xla/service/hlo_constant_folding.h
index 4557983..4a624cc 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding.h
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding.h
@@ -23,7 +23,7 @@
 
 // A pass which performs constant folding in order to avoid unnecessary
 // computation on constants.
-class HloConstantFolding : public HloPassInterface {
+class HloConstantFolding : public HloModulePass {
  public:
   absl::string_view name() const override { return "constant_folding"; }
 
diff --git a/tensorflow/compiler/xla/service/hlo_cse.h b/tensorflow/compiler/xla/service/hlo_cse.h
index a28c035..e4857fd 100644
--- a/tensorflow/compiler/xla/service/hlo_cse.h
+++ b/tensorflow/compiler/xla/service/hlo_cse.h
@@ -25,7 +25,7 @@
 // and identical instructions with the same operands are commoned. The pass
 // iterates over the instructions in topological order which enables the pass to
 // find arbitrarily large common expressions.
-class HloCSE : public HloPassInterface {
+class HloCSE : public HloModulePass {
  public:
   // If is_layout_sensitive is true, then the simplifier preserves layout during
   // transformation. Otherwise, layout is ignored.
diff --git a/tensorflow/compiler/xla/service/hlo_dce.h b/tensorflow/compiler/xla/service/hlo_dce.h
index 1fe69b1..4012042 100644
--- a/tensorflow/compiler/xla/service/hlo_dce.h
+++ b/tensorflow/compiler/xla/service/hlo_dce.h
@@ -33,7 +33,7 @@
 //
 // This pass does not remove dead parameter instructions, as parameter
 // instructions cannot be deleted.
-class HloDCE : public HloPassInterface {
+class HloDCE : public HloModulePass {
  public:
   ~HloDCE() override {}
   absl::string_view name() const override { return "dce"; }
diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.h b/tensorflow/compiler/xla/service/hlo_domain_isolator.h
index d36631f..c0bf1b9 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_isolator.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.h
@@ -30,7 +30,7 @@
 // used to break an HLO graph edge connecting two instructions with different
 // sharding. If a set of connected instructions have all the same sharding, no
 // kDomain instruction will be placed.
-class HloDomainIsolator : public HloPassInterface {
+class HloDomainIsolator : public HloModulePass {
  public:
   // Creates a new kDomain instruction for the edge between the use instruction
   // (the first HloInstruction argument), and the operand instruction (the
diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.h b/tensorflow/compiler/xla/service/hlo_domain_remover.h
index 97bc8ef..0fc30fb 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_remover.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_remover.h
@@ -26,7 +26,7 @@
 // Removes all the kDomain instructions of a given kind from the input module,
 // and calls the normalizer to propagate the properties on the possibly new born
 // instructions.
-class HloDomainRemover : public HloPassInterface {
+class HloDomainRemover : public HloModulePass {
  public:
   // Creates a new HloDomainRemover object tasked at removing all the kDomain
   // instructions of a given kind.
diff --git a/tensorflow/compiler/xla/service/hlo_domain_verifier.h b/tensorflow/compiler/xla/service/hlo_domain_verifier.h
index 81d6d69..bea5cba 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_verifier.h
@@ -29,7 +29,7 @@
 
 // Verifies that the domain instructions are consistent, and that each domain is
 // surrounded by the same metadata.
-class HloDomainVerifier : public HloPassInterface {
+class HloDomainVerifier : public HloModulePass {
  public:
   HloDomainVerifier(std::vector<string> kinds) : kinds_(std::move(kinds)) {}
 
diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.h b/tensorflow/compiler/xla/service/hlo_element_type_converter.h
index 44ded2c..4d2a942 100644
--- a/tensorflow/compiler/xla/service/hlo_element_type_converter.h
+++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.h
@@ -25,7 +25,7 @@
 // inserting Convert ops. This allows a backend to support an element type while
 // only actually implementing the Convert op for that element type. This is
 // generally not the fastest approach, but it works.
-class HloElementTypeConverter : public HloPassInterface {
+class HloElementTypeConverter : public HloModulePass {
  public:
   // eliminate_type is the type to eliminate as the input or output of ops,
   // using Convert ops to replace it with replace_with_type.
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index e905f29..ad58833 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -2910,6 +2910,26 @@
   return os << ToString(kind);
 }
 
+bool HloPtrComparator::operator()(const HloInstruction* const& lhs,
+                                  const HloInstruction* const& rhs) const {
+  if (rhs == nullptr) {
+    // Nothing compares less than nullptr.
+    return false;
+  }
+  if (lhs == nullptr) {
+    return true;
+  }
+  auto lhs_module = lhs->GetModule();
+  auto rhs_module = rhs->GetModule();
+  CHECK((lhs_module == nullptr && rhs_module == nullptr) ||
+        (lhs_module != nullptr && rhs_module != nullptr));
+  if (lhs_module != nullptr &&
+      lhs_module->unique_id() != rhs_module->unique_id()) {
+    return lhs_module->unique_id() < rhs_module->unique_id();
+  }
+  return lhs->unique_id() < rhs->unique_id();
+}
+
 bool HloInstruction::CouldBeBitcast() const {
   switch (opcode_) {
     case HloOpcode::kTranspose:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 4f6cac1..d615df0 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -1616,6 +1616,10 @@
   InstructionVector operands_;
 
   // The set of control predecessors of this instruction.
+  // Note that the order of the instructions in the vector influences the order
+  // computed in HloComputation::ComputeInstructionPostOrder, which may
+  // influence the result of the compilation by changing the scheduling. We are
+  // not sure if it matters.
   std::vector<HloInstruction*> control_predecessors_;
 
   // The users of this instruction. Users are HLOs where this instruction is an
@@ -1689,21 +1693,9 @@
 // To make the iteration order over the map deterministic, the comparator
 // should not be using the pointer values, but rather an intrinsic property of
 // the hlo. Exception: null pointer values compare less than non-null.
-//
-// Note that this cannot be used for HLO instructions across multiple modules
-// since the id of HLO instructions are only unique within each HLO module.
 struct HloPtrComparator {
   bool operator()(const HloInstruction* const& lhs,
-                  const HloInstruction* const& rhs) const {
-    if (rhs == nullptr) {
-      // Nothing compares less than nullptr.
-      return false;
-    }
-    if (lhs == nullptr) {
-      return true;
-    }
-    return lhs->unique_id() < rhs->unique_id();
-  }
+                  const HloInstruction* const& rhs) const;
 };
 
 template <typename ValueT>
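
Moving the comparator's body out of line lets it consult HloInstruction::GetModule(), so ordering is now by module unique_id first and instruction unique_id second, which makes maps keyed on instructions from different modules iterate deterministically. A small sketch, assuming hlo_a and hlo_b are existing HloInstruction pointers:

    #include <map>
    #include "tensorflow/compiler/xla/service/hlo_instruction.h"

    // Iteration order depends only on (module unique_id, instruction unique_id),
    // never on the pointer values the allocator happened to produce.
    std::map<const xla::HloInstruction*, int, xla::HloPtrComparator> visit_counts;
    ++visit_counts[hlo_a];
    ++visit_counts[hlo_b];
    for (const auto& entry : visit_counts) {
      // Visited in a stable, run-to-run reproducible order.
    }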
diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc
index 3a1dd47..5bf055f 100644
--- a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc
@@ -219,6 +219,33 @@
   }
 }
 
+// Makes sure that if a live instruction is within a computation used in
+// control flow operations, the instructions related to that control flow are
+// marked live as well.
+void PropagateLivenessThroughControlFlow(
+    const HloInstruction* instruction,
+    HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist,
+    Workset* workset, CallGraph* call_graph) {
+  const CallGraphNode& call_graph_node =
+      call_graph->GetNode(instruction->parent());
+  if (call_graph_node.context() == CallContext::kSequential) {
+    for (const CallSite& callsite : call_graph_node.caller_callsites()) {
+      HloInstruction* caller = callsite.instruction();
+      if (caller->opcode() == HloOpcode::kWhile) {
+        // If a live instruction is within the %while body or condition
+        // computation, mark the predicate value returned by the condition
+        // computation live as well.
+        MarkLiveAtIndex(caller->while_condition()->root_instruction(), {},
+                        live_index_map, worklist, workset);
+      } else if (caller->opcode() == HloOpcode::kConditional) {
+        // If a live instruction is within the true or false branches of a
+        // conditional, we mark the predicate operand live as well.
+        MarkLiveAtIndex(caller->operand(0), {}, live_index_map, worklist,
+                        workset);
+      }
+    }
+  }
+}
+
 }  // namespace
 
 HloLivenessAnalysis::HloLivenessAnalysis(const HloModule& module)
@@ -257,12 +284,10 @@
     } else if (instruction->opcode() == HloOpcode::kGetTupleElement) {
       PropagateLivenessThroughGTE(instruction, &live_index_map_, &worklist,
                                   &workset);
-    } else if (instruction->opcode() == HloOpcode::kWhile &&
-               ShapeUtil::IsTuple(instruction->shape())) {
+    } else if (instruction->opcode() == HloOpcode::kWhile) {
       PropagateLivenessThroughWhile(instruction, &live_index_map_, &worklist,
                                     &workset);
-    } else if (instruction->opcode() == HloOpcode::kParameter &&
-               ShapeUtil::IsTuple(instruction->shape())) {
+    } else if (instruction->opcode() == HloOpcode::kParameter) {
       PropagateLivenessToParameterCallers(instruction, &live_index_map_,
                                           &worklist, &workset,
                                           call_graph_.get());
@@ -277,6 +302,8 @@
         MarkLiveAtAllIndices(operand, &live_index_map_, &worklist, &workset);
       }
     }
+    PropagateLivenessThroughControlFlow(instruction, &live_index_map_,
+                                        &worklist, &workset, call_graph_.get());
   }
 }
 
diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc
index 01b625c..e0ae117 100644
--- a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc
@@ -398,5 +398,89 @@
   EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {2}));
 }
 
+TEST_F(HloLivenessAnalysisTest, WhileWithOutfeed) {
+  auto module = ParseHloString(R"(
+  HloModule OutfeedLoop
+  WhileBody {
+    body_param = (s32[]) parameter(0)
+    token = token[] after-all()
+    constant.2 = s32[] constant(2)
+    outfeed_tuple = (s32[]) outfeed(constant.2, token)
+    get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0
+    constant.1 = s32[] constant(1)
+    add = s32[] add(get-tuple-element.1, constant.1)
+    ROOT tuple = (s32[]) tuple(add)
+  }
+  WhileCondition {
+    cond_param = (s32[]) parameter(0)
+    get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
+    constant.2 = s32[] constant(10)
+    ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
+  }
+  ENTRY SimpleLoop {
+    constant.3 = s32[] constant(0)
+    tuple.1 = (s32[]) tuple(constant.3)
+    while = (s32[]) while(tuple.1), condition=WhileCondition,
+      body=WhileBody
+    ROOT rtuple = () tuple()
+  })")
+                    .ValueOrDie();
+
+  const HloLivenessAnalysis& liveness = RunLiveness(module.get());
+  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add"), {}));
+  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {}));
+}
+
+TEST_F(HloLivenessAnalysisTest, NestedWhileWithOutfeed) {
+  auto module = ParseHloString(R"(
+  HloModule OutfeedLoop
+  InnerWhileBody {
+    body_param = (s32[]) parameter(0)
+    token = token[] after-all()
+    constant.2 = s32[] constant(2)
+    outfeed_tuple = (s32[]) outfeed(constant.2, token)
+    get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0
+    constant.1 = s32[] constant(1)
+    add = s32[] add(get-tuple-element.1, constant.1)
+    ROOT tuple = (s32[]) tuple(add)
+  }
+  InnerWhileCondition {
+    cond_param = (s32[]) parameter(0)
+    get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
+    constant.2 = s32[] constant(10)
+    ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
+  }
+  OuterWhileCondition {
+    cond_param.2 = (s32[]) parameter(0)
+    get-tuple-element.5 = s32[] get-tuple-element(cond_param.2), index=0
+    constant.5 = s32[] constant(5)
+    ROOT less-than.2 = pred[] less-than(get-tuple-element.5, constant.5)
+  }
+  OuterWhileBody {
+    body_param.2 = (s32[]) parameter(0)
+    get-tuple-element.8 = s32[] get-tuple-element(body_param.2), index=0
+    constant.6 = s32[] constant(0)
+    tuple.2 = (s32[]) tuple(constant.6)
+    inner_while = (s32[]) while(tuple.2), condition=InnerWhileCondition,
+      body=InnerWhileBody
+    constant.7 = s32[] constant(1)
+    add.2 = s32[] add(get-tuple-element.8, constant.7)
+    ROOT rtuple = (s32[]) tuple(add.2)
+  }
+  ENTRY SimpleLoop {
+    constant.3 = s32[] constant(0)
+    tuple.1 = (s32[]) tuple(constant.3)
+    while = (s32[]) while(tuple.1), condition=OuterWhileCondition,
+      body=OuterWhileBody
+    ROOT rtuple = () tuple()
+  })")
+                    .ValueOrDie();
+
+  const HloLivenessAnalysis& liveness = RunLiveness(module.get());
+  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add"), {}));
+  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.2"), {}));
+  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {}));
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h
index 5e02868..9964c6f 100644
--- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h
+++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h
@@ -90,7 +90,7 @@
 // A pass which schedules the HLO instructions in a module. The HloModule's
 // schedule field is set to the resulting HloSchedule using
 // HloModule::set_schedule.
-class HloMemoryScheduler : public HloPassInterface {
+class HloMemoryScheduler : public HloModulePass {
  public:
   // size_function is the function returning the number of bytes required for a
   // LogicalBuffer. algorithm is the memory scheduling algorithm to use. If not
@@ -109,7 +109,7 @@
 
 // A trivial pass which clears the schedule currently set on the
 // HloModule. After this pass runs, HloModule::has_schedule will return false.
-class HloDescheduler : public HloPassInterface {
+class HloDescheduler : public HloModulePass {
  public:
   HloDescheduler() = default;
   ~HloDescheduler() override = default;
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index 3bc2d13..735804e 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -63,6 +63,7 @@
   // tests). The versioned handle is used by the service in the compilation
   // cache. A default configuration is created for this module.
   explicit HloModule(const string& name, const HloModuleConfig& config);
+  virtual ~HloModule() {}
 
   // Adds an entry computation to the module. A module can only have one entry
   // computation. Returns a pointer to the newly added computation.
@@ -87,6 +88,7 @@
       const std::unordered_map<HloComputation*, HloComputation*>& replacements);
 
   const string& name() const { return name_; }
+  void set_name(string name) { name_ = std::move(name); }
 
   // Returns a deep copy of this module including all computations.
   std::unique_ptr<HloModule> Clone(const string& suffix = "clone") const;
@@ -255,7 +257,7 @@
       std::unique_ptr<HloComputation> computation, bool is_entry,
       bool uniquify_identifiers);
 
-  const string name_;
+  string name_;
   HloModuleConfig config_;
   HloComputation* entry_computation_ = nullptr;
   std::vector<std::unique_ptr<HloComputation>> computations_;
diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.cc b/tensorflow/compiler/xla/service/hlo_module_dce.cc
index f7be5ca..31d26cc 100644
--- a/tensorflow/compiler/xla/service/hlo_module_dce.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_dce.cc
@@ -50,9 +50,7 @@
       auto* while_body_root = while_body_comp->root_instruction();
 
       if (!ShapeUtil::IsTuple(xla_while->shape()) ||
-          while_body_root->opcode() != HloOpcode::kTuple ||
-          while_body_comp->HasSideEffect() ||
-          xla_while->while_condition()->HasSideEffect()) {
+          while_body_root->opcode() != HloOpcode::kTuple) {
        // Only run DCE on tuple-shaped while loops where the body root is a
        // Tuple instruction.
         VLOG(1) << "WhileDCE SKIP while: " << xla_while->ToString();
diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.h b/tensorflow/compiler/xla/service/hlo_module_dce.h
index 12ca234..d472211 100644
--- a/tensorflow/compiler/xla/service/hlo_module_dce.h
+++ b/tensorflow/compiler/xla/service/hlo_module_dce.h
@@ -28,7 +28,7 @@
 // Sweeps through live instructions which cross computation boundaries (kWhile),
 // and removes code at dead shape indices.
 //
-class HloModuleDCE : public HloPassInterface {
+class HloModuleDCE : public HloModulePass {
  public:
   ~HloModuleDCE() override {}
   absl::string_view name() const override { return "hlo-module-dce"; }
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
index 9c01862..83352ef 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
@@ -392,22 +392,28 @@
   if (!ContainsKey(companion_set_index_, instruction1) &&
       !ContainsKey(companion_set_index_, instruction2)) {
     companion_sets_.push_back(
-        absl::make_unique<std::unordered_set<HloInstruction*>>());
+        absl::make_unique<std::vector<HloInstruction*>>());
     auto companion_set = companion_sets_.back().get();
-    companion_set->insert(instruction1);
-    companion_set->insert(instruction2);
+    companion_set->push_back(instruction1);
+    companion_set->push_back(instruction2);
     companion_set_index_[instruction1] = companion_sets_.size() - 1;
     companion_set_index_[instruction2] = companion_sets_.size() - 1;
   } else if (!ContainsKey(companion_set_index_, instruction1)) {
-    companion_sets_[companion_set_index_[instruction2]]->insert(instruction1);
+    companion_sets_[companion_set_index_[instruction2]]->push_back(
+        instruction1);
     companion_set_index_[instruction1] = companion_set_index_[instruction2];
   } else if (!ContainsKey(companion_set_index_, instruction2)) {
-    companion_sets_[companion_set_index_[instruction1]]->insert(instruction2);
+    companion_sets_[companion_set_index_[instruction1]]->push_back(
+        instruction2);
     companion_set_index_[instruction2] = companion_set_index_[instruction1];
   } else if (companion_set_index_[instruction1] !=
              companion_set_index_[instruction2]) {
-    companion_sets_[companion_set_index_[instruction1]]->insert(
-        Companions(instruction2).begin(), Companions(instruction2).end());
+    // At any point while building the companion sets, each instruction belongs
+    // to at most one companion set, so taking the union of two companion sets
+    // amounts to concatenating two disjoint sets.
+    absl::c_copy(Companions(instruction2),
+                 std::back_inserter(
+                     *companion_sets_[companion_set_index_[instruction1]]));
     int64 index_to_remove = companion_set_index_[instruction2];
     for (HloInstruction* hlo : Companions(instruction2)) {
       companion_set_index_[hlo] = companion_set_index_[instruction1];
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
index 768b0c7..278d94c 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
@@ -169,14 +169,14 @@
   // Returns the companion instructions for the given instruction.
   //
   // Precondition: IsCompanionWhile(instruction) is true.
-  const std::unordered_set<HloInstruction*>& Companions(
+  const std::vector<HloInstruction*>& Companions(
       const HloInstruction* instruction) const {
     CHECK_EQ(companion_set_index_.count(instruction), 1);
     return companion_set(companion_set_index_.at(instruction));
   }
 
   // Returns the companion set at the given index.
-  const std::unordered_set<HloInstruction*>& companion_set(int64 index) const {
+  const std::vector<HloInstruction*>& companion_set(int64 index) const {
     CHECK_LT(index, companion_sets_.size());
     return *companion_sets_[index];
   }
@@ -187,7 +187,7 @@
   }
 
   // Returns the list of all companion sets in the HLO module group.
-  const std::vector<std::unique_ptr<std::unordered_set<HloInstruction*>>>&
+  const std::vector<std::unique_ptr<std::vector<HloInstruction*>>>&
   companion_sets() const {
     return companion_sets_;
   }
@@ -247,8 +247,7 @@
   void DumpCollectedStats() const;
 
   // List of all companion instructions sets in the module.
-  std::vector<std::unique_ptr<std::unordered_set<HloInstruction*>>>
-      companion_sets_;
+  std::vector<std::unique_ptr<std::vector<HloInstruction*>>> companion_sets_;
 
   // Map from each companion while instruction to the index into companion_set_.
   tensorflow::gtl::FlatMap<const HloInstruction*, int64> companion_set_index_;
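
Because companion sets are now vectors, iterating a set visits companions in the order they were added, so code that walks the sets produces the same result on every run. A brief sketch, assuming metadata is the result of HloModuleGroupMetadata::Build as in the test added later in this patch:

    for (const auto& companion_set : metadata->companion_sets()) {
      for (xla::HloInstruction* companion : *companion_set) {
        // Companions arrive in insertion order, which is deterministic.
      }
    }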
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_test.cc b/tensorflow/compiler/xla/service/hlo_module_group_test.cc
index ebf790b..b7b12cb 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_test.cc
@@ -17,6 +17,7 @@
 
 #include "tensorflow/compiler/xla/service/hlo.pb.h"
 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h"
 #include "tensorflow/compiler/xla/service/hlo_parser.h"
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
@@ -137,6 +138,69 @@
               ::testing::ElementsAre(op::Parameter()));
 }
 
+// Tests that the order of companion instructions in the companion set doesn't
+// change across runs.
+TEST_F(HloModuleGroupTest, ModuleGroupCompanionOrder) {
+  // A simple while loop template for core i sending to core i+1.
+  constexpr char text[] = R"(
+HloModule module_%d
+
+while_cond {
+  ROOT p = pred[] constant(true)
+}
+
+while_body {
+  param = s32[] parameter(0)
+  token.s = token[] after-all()
+  token.r = token[] after-all()
+  send = (s32[], u32[], token[]) send(param, token.s), channel_id=%d
+  send-done = token[] send-done(send), channel_id=%d
+  recv = (s32[], u32[], token[]) recv(token.r), channel_id=%d
+  ROOT recv-done = (s32[], token[]) recv-done(recv), channel_id=%d
+}
+
+ENTRY entry {
+  while_init = s32[] constant(1)
+  ROOT while = s32[] while(while_init), condition=while_cond, body=while_body
+}
+)";
+
+  // Try creating the module and the metadata kTrialCount times and check the
+  // companion instructions remain in the same order.
+  const int64 kTrialCount = 5;
+  const int64 kDeviceCount = 10;
+  std::vector<int64> companion_order;
+
+  for (int64 t = 0; t < kTrialCount; ++t) {
+    HloModuleGroup group(TestName());
+    for (int64 i = 0; i < kDeviceCount; ++i) {
+      const int64 send_channel = i;
+      const int64 recv_channel = i == 0 ? kDeviceCount - 1 : i - 1;
+      TF_ASSERT_OK_AND_ASSIGN(
+          std::unique_ptr<HloModule> module,
+          ParseHloString(absl::StrFormat(text, i, send_channel, send_channel,
+                                         recv_channel, recv_channel)));
+      group.push_back(std::move(module));
+    }
+    ASSERT_EQ(group.modules().size(), kDeviceCount);
+
+    TF_ASSERT_OK_AND_ASSIGN(auto metadata,
+                            HloModuleGroupMetadata::Build(group.modules()));
+    ASSERT_EQ(metadata->companion_sets().size(), 1);
+
+    std::vector<int64> module_ids;
+    for (HloInstruction* companion : *metadata->companion_sets()[0]) {
+      module_ids.push_back(metadata->GetModuleId(companion->GetModule()));
+    }
+
+    if (t == 0) {
+      companion_order = module_ids;
+    } else {
+      EXPECT_TRUE(absl::c_equal(companion_order, module_ids));
+    }
+  }
+}
+
 }  // namespace
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 11caa89..37197b2 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -64,14 +64,11 @@
  public:
   using LocTy = HloLexer::LocTy;
 
-  explicit HloParser(absl::string_view str, const HloModuleConfig& config)
-      : lexer_(str), config_(config) {}
+  explicit HloParser(absl::string_view str) : lexer_(str) {}
 
-  // Runs the parser. Returns false if an error occurred.
-  bool Run();
-
-  // Returns the parsed HloModule.
-  std::unique_ptr<HloModule> ConsumeHloModule() { return std::move(module_); }
+  // Runs the parser and constructs the resulting HLO in the given (empty)
+  // HloModule. Returns false if an error occurred.
+  bool Run(HloModule* module);
 
   // Returns the error information.
   string GetError() const { return StrJoin(error_, "\n"); }
@@ -98,8 +95,8 @@
       const string& name, const optional<Shape>& shape = nullopt);
 
   // ParseXXX returns false if an error occurred.
-  bool ParseHloModule();
-  bool ParseComputations();
+  bool ParseHloModule(HloModule* module);
+  bool ParseComputations(HloModule* module);
   bool ParseComputation(HloComputation** entry_computation);
   bool ParseInstructionList(HloComputation::Builder* builder,
                             string* root_name);
@@ -293,9 +290,7 @@
       computation_pool_;
 
   HloLexer lexer_;
-  std::unique_ptr<HloModule> module_;
   std::vector<std::unique_ptr<HloComputation>> computations_;
-  const HloModuleConfig config_;
   std::vector<string> error_;
 
   // Function that gets invoked when we try to resolve an instruction
@@ -349,9 +344,9 @@
   return Error(lexer_.GetLoc(), msg);
 }
 
-bool HloParser::Run() {
+bool HloParser::Run(HloModule* module) {
   lexer_.Lex();
-  return ParseHloModule();
+  return ParseHloModule(module);
 }
 
 std::pair<HloInstruction*, HloParser::LocTy>* HloParser::FindInstruction(
@@ -366,7 +361,7 @@
 }
 
 // ::= 'HloModule' name computations
-bool HloParser::ParseHloModule() {
+bool HloParser::ParseHloModule(HloModule* module) {
   if (lexer_.GetKind() != TokKind::kw_HloModule) {
     return TokenError("expects HloModule");
   }
@@ -385,22 +380,20 @@
     return false;
   }
 
-  module_ = absl::make_unique<HloModule>(name, config_);
-
-  if (!ParseComputations()) {
+  module->set_name(name);
+  if (!ParseComputations(module)) {
     return false;
   }
 
   if (is_scheduled.has_value() && *is_scheduled) {
-    TF_CHECK_OK(
-        module_->set_schedule(ScheduleFromInstructionOrder(module_.get())));
+    TF_CHECK_OK(module->set_schedule(ScheduleFromInstructionOrder(module)));
   }
 
   return true;
 }
 
 // computations ::= (computation)+
-bool HloParser::ParseComputations() {
+bool HloParser::ParseComputations(HloModule* module) {
   HloComputation* entry_computation = nullptr;
   do {
     if (!ParseComputation(&entry_computation)) {
@@ -416,21 +409,20 @@
     if ((entry_computation != nullptr &&
          computations_[i].get() != entry_computation) ||
         (entry_computation == nullptr && i != computations_.size() - 1)) {
-      module_->AddEmbeddedComputation(std::move(computations_[i]));
+      module->AddEmbeddedComputation(std::move(computations_[i]));
       continue;
     }
-    auto computation =
-        module_->AddEntryComputation(std::move(computations_[i]));
+    auto computation = module->AddEntryComputation(std::move(computations_[i]));
     // The parameters and result layouts were set to default layout. Here we
     // set the layouts to what the hlo text says.
     for (int p = 0; p < computation->num_parameters(); p++) {
       const Shape& param_shape = computation->parameter_instruction(p)->shape();
-      TF_CHECK_OK(module_->mutable_entry_computation_layout()
+      TF_CHECK_OK(module->mutable_entry_computation_layout()
                       ->mutable_parameter_layout(p)
                       ->CopyLayoutFromShape(param_shape));
     }
     const Shape& result_shape = computation->root_instruction()->shape();
-    TF_CHECK_OK(module_->mutable_entry_computation_layout()
+    TF_CHECK_OK(module->mutable_entry_computation_layout()
                     ->mutable_result_layout()
                     ->CopyLayoutFromShape(result_shape));
   }
@@ -3247,53 +3239,62 @@
 
 StatusOr<std::unique_ptr<HloModule>> ParseHloString(
     absl::string_view str, const HloModuleConfig& config) {
-  HloParser parser(str, config);
-  if (!parser.Run()) {
+  auto module = absl::make_unique<HloModule>(/*name=*/"", config);
+  HloParser parser(str);
+  if (!parser.Run(module.get())) {
     return InvalidArgument("Syntax error:\n%s", parser.GetError());
   }
-  return parser.ConsumeHloModule();
+  return std::move(module);
 }
 
 StatusOr<std::unique_ptr<HloModule>> ParseHloString(absl::string_view str) {
-  HloModuleConfig config;
-  return ParseHloString(str, config);
+  auto module = absl::make_unique<HloModule>(/*name=*/"", HloModuleConfig());
+  HloParser parser(str);
+  if (!parser.Run(module.get())) {
+    return InvalidArgument("Syntax error:\n%s", parser.GetError());
+  }
+  return std::move(module);
+}
+
+Status ParseHloString(absl::string_view str, HloModule* module) {
+  TF_RET_CHECK(module->computation_count() == 0);
+  HloParser parser(str);
+  if (!parser.Run(module)) {
+    return InvalidArgument("Syntax error:\n%s", parser.GetError());
+  }
+  return Status::OK();
 }
 
 StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule(
     absl::string_view str, absl::string_view name) {
-  HloModuleConfig config;
-  HloParser parser(str, config);
+  HloParser parser(str);
   auto builder = absl::make_unique<HloComputation::Builder>(string(name));
   string root_name;
   TF_RETURN_IF_ERROR(parser.ParseSingleInstruction(builder.get(), &root_name));
   std::unique_ptr<HloComputation> computation = builder->Build();
-  auto module = absl::make_unique<HloModule>(string(name), config);
+  auto module = absl::make_unique<HloModule>(string(name), HloModuleConfig());
   module->AddEntryComputation(std::move(computation));
   return std::move(module);
 }
 
 StatusOr<HloSharding> ParseSharding(absl::string_view str) {
-  HloModuleConfig config;
-  HloParser parser(str, config);
+  HloParser parser(str);
   return parser.ParseShardingOnly();
 }
 
 StatusOr<Window> ParseWindow(absl::string_view str) {
-  HloModuleConfig config;
-  HloParser parser(str, config);
+  HloParser parser(str);
   return parser.ParseWindowOnly();
 }
 
 StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers(
     absl::string_view str) {
-  HloModuleConfig config;
-  HloParser parser(str, config);
+  HloParser parser(str);
   return parser.ParseConvolutionDimensionNumbersOnly();
 }
 
 StatusOr<PaddingConfig> ParsePaddingConfig(absl::string_view str) {
-  HloModuleConfig config;
-  HloParser parser(str, config);
+  HloParser parser(str);
   return parser.ParsePaddingConfigOnly();
 }
 
diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h
index 1882a18..3696035 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.h
+++ b/tensorflow/compiler/xla/service/hlo_parser.h
@@ -30,18 +30,23 @@
 // For details about the syntax accepted by this parser, see
 // g3doc/hlo_parser.md.
 
-// The api of the hlo parser. Given a string in the HloModule::ToString()
-// format, parses the string and creates a HloModule with the given config.
+// Given a string in the HloModule::ToString() format, parses the string and
+// creates a HloModule with the given config.
 StatusOr<std::unique_ptr<HloModule>> ParseHloString(
     absl::string_view str, const HloModuleConfig& config);
 
+// Given a string in the HloModule::ToString() format, parses the string and
+// builds the HloModule in place at the given module pointer. 'module' must
+// point to an empty module (no computations).
+Status ParseHloString(absl::string_view str, HloModule* module);
+
 // Parses the text for a single HLO operation into an HLO module with a function
 // that runs that operation (with the same parameters) as its entry computation.
 StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule(
     absl::string_view str, absl::string_view name = "single_op");
 
-// The api of the hlo parser. Given a string in the HloModule::ToString()
-// format, parses the string and creates a HloModule with default config.
+// Given a string in the HloModule::ToString() format, parses the string and
+// creates a HloModule with default config.
 StatusOr<std::unique_ptr<HloModule>> ParseHloString(absl::string_view str);
 
 // Parses the result of HloSharding::ToString(), e.g. "{replicated}".
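
The in-place overload is useful when the caller wants to own the HloModule (for example to subclass it, which the new virtual destructor allows) and have the parser populate it. A hedged sketch, where hlo_text stands for any string in HloModule::ToString() format:

    #include "tensorflow/compiler/xla/service/hlo_module.h"
    #include "tensorflow/compiler/xla/service/hlo_parser.h"

    // The module must be empty (no computations); the parser sets its name and
    // builds its computations in place.
    xla::HloModule module(/*name=*/"", xla::HloModuleConfig());
    TF_CHECK_OK(xla::ParseHloString(hlo_text, &module));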
diff --git a/tensorflow/compiler/xla/service/hlo_pass_interface.h b/tensorflow/compiler/xla/service/hlo_pass_interface.h
index f1ad0f9..fdaac34 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_interface.h
+++ b/tensorflow/compiler/xla/service/hlo_pass_interface.h
@@ -17,6 +17,7 @@
 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_INTERFACE_H_
 
 #include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_module_group.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
@@ -25,15 +26,45 @@
 namespace xla {
 
 // Base class for HLO passes. These are used with the HloPassPipeline to
-// organize a sequence of passes.
+// organize a sequence of passes. An HLO pass should not extend this class
+// directly; it should extend HloModulePass or HloModuleGroupPass.
 class HloPassInterface {
  public:
   virtual ~HloPassInterface() = default;
   virtual absl::string_view name() const = 0;
 
-  // Run the pass on the given HLO module.  Return whether it modified the
+  // Run the pass on the given HLO module.  Returns whether it modified the
   // module.
   virtual StatusOr<bool> Run(HloModule* module) = 0;
+
+  // Run the pass on the given HLO module group. Returns whether it modified the
+  // module group. Ideally, the module group variant would be named "Run" as
+  // well, but C++ does not handle overloaded virtual methods well.
+  virtual StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) = 0;
+};
+
+// Base class for passes which are module-scoped.
+class HloModulePass : public HloPassInterface {
+ public:
+  // Runs the pass on a module group by iterating through each module in the
+  // group.
+  StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) override {
+    bool changed = false;
+    for (HloModule* module : module_group->modules()) {
+      TF_ASSIGN_OR_RETURN(bool module_changed, Run(module));
+      changed |= module_changed;
+    }
+    return changed;
+  }
+};
+
+// Base class for passes which are module-group scoped. These passes cannot run
+// on an HLO module.
+class HloModuleGroupPass : public HloPassInterface {
+ public:
+  StatusOr<bool> Run(HloModule* module) override {
+    return InternalError("Module group pass cannot be run on a module");
+  }
 };
 
 }  // namespace xla
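As a hedged sketch of the new split (the pass below is hypothetical, not part of this patch), a module-scoped pass only overrides Run; module-group handling is inherited from HloModulePass:

    // Sketch only; "NoOpModulePass" is a hypothetical example pass.
    class NoOpModulePass : public HloModulePass {
     public:
      absl::string_view name() const override { return "no-op"; }
      StatusOr<bool> Run(HloModule* module) override {
        return false;  // Reports that nothing was changed.
      }
      // RunOnModuleGroup() is inherited: it calls Run() on each module in the
      // group and ORs the results together.
    };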
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
index 6e4ed0d..8c2f928 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
@@ -17,7 +17,6 @@
 
 #include <functional>
 
-#include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
 #include "absl/strings/str_join.h"
 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
@@ -29,108 +28,128 @@
 #include "tensorflow/core/platform/logging.h"
 
 namespace xla {
-namespace {
 
-using absl::StrAppend;
-using absl::StrCat;
+template <typename HloT>
+Status HloPassPipeline::RunInvariantCheckers(
+    HloT* hlo, absl::string_view after_pass_name) {
+  for (auto& invariant_checker : invariant_checkers_) {
+    VLOG(1) << "    Invariant checker " << invariant_checker->name();
+    StatusOr<bool> changed_status = RunHelper(invariant_checker.get(), hlo);
+    VLOG(1) << "    Invariant checker done " << invariant_checker->name();
+    if (!changed_status.ok()) {
+      VLOG(2) << "Failed invariant check:";
+      XLA_VLOG_LINES(2, hlo->ToString());
+      return Status(changed_status.status().code(),
+                    absl::StrCat(changed_status.status().error_message(),
+                                 "\n\nFailed after ", after_pass_name));
+    }
+    TF_RET_CHECK(!changed_status.ValueOrDie())
+        << "invariant checkers must not change the graph";
+  }
+  return Status::OK();
+}
 
-void DumpModuleGraph(const HloModule& module, const string& message) {
+template <typename HloT>
+StatusOr<bool> HloPassPipeline::RunPassesInternal(
+    HloT* hlo, absl::Span<HloPassInterface* const> passes) {
+  string last_pass_name = "pipeline-start";
+  TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, last_pass_name));
+  bool changed = false;
+  for (HloPassInterface* pass : passes) {
+    VLOG(1) << "  HLO pass " << pass->name();
+    MaybeDumpHlo(*hlo,
+                 /*after_pass_name=*/last_pass_name,
+                 /*before_pass_name=*/pass->name());
+    TF_ASSIGN_OR_RETURN(bool pass_changed, RunHelper(pass, hlo));
+    changed |= pass_changed;
+    TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, pass->name()));
+    last_pass_name = string(pass->name());
+  }
+  MaybeDumpHlo(*hlo,
+               /*after_pass_name=*/last_pass_name,
+               /*before_pass_name=*/"pipeline-end");
+  return changed;
+}
+
+std::vector<HloPassInterface*> HloPassPipeline::GetEnabledPasses(
+    const DebugOptions& debug_options) {
+  auto repeated_field = debug_options.xla_disable_hlo_passes();
+  tensorflow::gtl::FlatSet<string> disabled_pass_names(repeated_field.begin(),
+                                                       repeated_field.end());
+  if (!disabled_pass_names.empty()) {
+    VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: "
+            << absl::StrJoin(disabled_pass_names, ", ");
+  }
+
+  std::vector<HloPassInterface*> enabled_passes;
+  for (auto& pass : passes_) {
+    if (disabled_pass_names.count(string(pass->name())) == 0) {
+      enabled_passes.push_back(pass.get());
+    }
+  }
+  return enabled_passes;
+}
+
+void HloPassPipeline::MaybeDumpHlo(const HloModule& module,
+                                   absl::string_view after_pass_name,
+                                   absl::string_view before_pass_name) {
+  const string& proto_dump_path =
+      module.config().debug_options().xla_dump_per_pass_hlo_proto_to();
+  if (!proto_dump_path.empty()) {
+    static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
+    static auto* const module_id_to_pass_number =
+        new tensorflow::gtl::FlatMap<int64, int64>();
+
+    tensorflow::mutex_lock lock(mu);
+    const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++;
+
+    const string filename = SanitizeFileName(
+        absl::StrFormat("module_%04d.%04d.%s.after_%s", module.unique_id(),
+                        pass_number, name(), after_pass_name));
+
+    TF_QCHECK_OK(protobuf_util::DumpProtoToDirectory(
+        MakeHloProto(module), proto_dump_path, filename));
+  }
+
+  const string message =
+      StrCat("after ", after_pass_name, ", before ", before_pass_name);
   hlo_graph_dumper::MaybeDumpHloModule(module, message);
   VLOG(3) << "HLO " << message << ":";
   XLA_VLOG_LINES(3, module.ToString());
 }
 
-void DumpModuleProto(const HloModule& module, const string& dump_to,
-                     const string& pipeline_name, const string& pass_name) {
-  static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
-  static auto* const module_id_to_pass_number =
-      new tensorflow::gtl::FlatMap<int64, int64>();
-
-  tensorflow::mutex_lock lock(mu);
-  const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++;
-
-  const string mod_name = SanitizeFileName(
-      absl::StrFormat("module_%04d.%04d.%s.after_%s", module.unique_id(),
-                      pass_number, pipeline_name, pass_name));
-
-  TF_QCHECK_OK(protobuf_util::DumpProtoToDirectory(MakeHloProto(module),
-                                                   dump_to, mod_name));
+void HloPassPipeline::MaybeDumpHlo(const HloModuleGroup& module_group,
+                                   absl::string_view after_pass_name,
+                                   absl::string_view before_pass_name) {
+  for (const HloModule* module : module_group.modules()) {
+    MaybeDumpHlo(*module, after_pass_name, before_pass_name);
+  }
 }
-}  // namespace
 
 StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
   run_called_ = true;
 
-  VLOG(1) << "Running HLO pass pipeline " << name();
+  VLOG(1) << "Running HLO pass pipeline on module " << module->name() << ": "
+          << name();
 
-  auto repeated_field =
-      module->config().debug_options().xla_disable_hlo_passes();
-  tensorflow::gtl::FlatSet<string> disabled_passes(repeated_field.begin(),
-                                                   repeated_field.end());
-  if (!disabled_passes.empty()) {
-    VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: "
-            << absl::StrJoin(disabled_passes, ", ");
+  return RunPassesInternal(module,
+                           GetEnabledPasses(module->config().debug_options()));
+}
+
+StatusOr<bool> HloPassPipeline::RunOnModuleGroup(HloModuleGroup* module_group) {
+  run_called_ = true;
+
+  VLOG(1) << "Running HLO pass pipeline on module group "
+          << module_group->name() << ": " << name();
+
+  if (module_group->modules().empty()) {
+    VLOG(1) << "Module group is empty. Nothing to do.";
+    return false;
   }
 
-  auto run_invariant_checkers = [this,
-                                 module](const string& message) -> Status {
-    for (auto& invariant_checker : invariant_checkers_) {
-      VLOG(1) << "    Invariant checker " << invariant_checker->name();
-      StatusOr<bool> changed_status = invariant_checker->Run(module);
-      VLOG(1) << "    Invariant checker done " << invariant_checker->name();
-      if (!changed_status.ok()) {
-        VLOG(2) << "Module failed invariant check:";
-        XLA_VLOG_LINES(2, module->ToString());
-        return Status(changed_status.status().code(),
-                      StrCat(changed_status.status().error_message(),
-                             "\n\nFailed ", message));
-      }
-      TF_RET_CHECK(!changed_status.ValueOrDie())
-          << "invariant checkers must not change the graph";
-    }
-    return Status::OK();
-  };
-
-  string prefix = StrCat(name(), ": pipeline start");
-  bool changed = false;
-  string message;
-  TF_RETURN_IF_ERROR(
-      run_invariant_checkers(StrCat("before running pipeline: ", name())));
-  const string xla_dump_per_pass_hlo_proto_to =
-      module->config().debug_options().xla_dump_per_pass_hlo_proto_to();
-  if (!xla_dump_per_pass_hlo_proto_to.empty()) {
-    DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, string(name()),
-                    "pipeline_start");
-  }
-
-  for (auto& pass : passes_) {
-    if (disabled_passes.count(string(pass->name())) > 0) {
-      VLOG(1) << "  Skipping HLO pass " << pass->name()
-              << ", disabled by --xla_disable_hlo_passes";
-      continue;
-    }
-
-    VLOG(1) << "  HLO pass " << pass->name();
-
-    // Emit label containing: "after foo-pass, before bar-pass".
-    message.clear();
-    StrAppend(&message, prefix, ", before ", pass->name());
-    DumpModuleGraph(*module, message);
-
-    TF_ASSIGN_OR_RETURN(bool changed_this_pass, pass->Run(module));
-    TF_RETURN_IF_ERROR(
-        run_invariant_checkers(StrCat("after running pass: ", pass->name())));
-    if (!xla_dump_per_pass_hlo_proto_to.empty()) {
-      DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, string(name()),
-                      string(pass->name()));
-    }
-
-    changed |= changed_this_pass;
-    prefix.clear();
-    StrAppend(&prefix, name(), ": after ", pass->name());
-  }
-  DumpModuleGraph(*module, prefix + ", pipeline end");
-  return changed;
+  return RunPassesInternal(
+      module_group,
+      GetEnabledPasses(module_group->module(0).config().debug_options()));
 }
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
index 1d41a4d..09e7033 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
@@ -22,6 +22,7 @@
 #include <vector>
 
 #include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
 #include "tensorflow/compiler/xla/statusor.h"
@@ -61,10 +62,45 @@
     return *pass;
   }
 
-  // Run all passes on the given HLO module.
   StatusOr<bool> Run(HloModule* module) override;
+  StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) override;
 
  private:
+  // Returns the set of passes which are enabled. DebugOptions can selectively
+  // disable passes via the --xla_disable_hlo_passes flag.
+  std::vector<HloPassInterface*> GetEnabledPasses(
+      const DebugOptions& debug_options);
+
+  // Maybe dumps the given module or module group, depending on flag values
+  // contained in the DebugOptions of the module config.
+  void MaybeDumpHlo(const HloModuleGroup& module_group,
+                    absl::string_view after_pass_name,
+                    absl::string_view before_pass_name);
+  void MaybeDumpHlo(const HloModule& module, absl::string_view after_pass_name,
+                    absl::string_view before_pass_name);
+
+  // Runs the invariant checkers on the given HLO. HloT can be either HloModule
+  // or HloModuleGroup.
+  template <typename HloT>
+  Status RunInvariantCheckers(HloT* hlo, absl::string_view after_pass_name);
+
+  // Helper which runs the given pass on the given HLO. HloT can be either
+  // HloModule or HloModuleGroup.
+  template <typename HloT>
+  StatusOr<bool> RunPassesInternal(HloT* hlo,
+                                   absl::Span<HloPassInterface* const> passes);
+
+  // Helpers which run the given pass on the given HLO construct. These
+  // helpers enable templating of the core pipeline logic by providing
+  // HloModule- and HloModuleGroup-specific methods with the same name.
+  static StatusOr<bool> RunHelper(HloPassInterface* pass, HloModule* module) {
+    return pass->Run(module);
+  }
+  static StatusOr<bool> RunHelper(HloPassInterface* pass,
+                                  HloModuleGroup* module_group) {
+    return pass->RunOnModuleGroup(module_group);
+  }
+
   const string name_;
   std::vector<std::unique_ptr<HloPassInterface>> passes_;
   std::vector<std::unique_ptr<HloPassInterface>> invariant_checkers_;
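A rough usage sketch of the pipeline API touched above (MyChecker, MyPassA and MyPassB are hypothetical HloModulePass subclasses, and the call assumes a surrounding function that returns Status or StatusOr):

    HloPassPipeline pipeline("example-pipeline");
    pipeline.AddInvariantChecker<MyChecker>();  // re-run after every pass
    pipeline.AddPass<MyPassA>();
    pipeline.AddPass<MyPassB>();  // skipped if named in --xla_disable_hlo_passes
    TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module.get()));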
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc
new file mode 100644
index 0000000..ee8cb12
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc
@@ -0,0 +1,259 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+namespace {
+
+class HloPassPipelineTest : public HloVerifiedTestBase {
+ protected:
+  StatusOr<HloModuleGroup> ParseModuleGroup(
+      absl::Span<const string> hlo_strings) {
+    HloModuleGroup group(TestName());
+    for (const string& hlo_string : hlo_strings) {
+      TF_ASSIGN_OR_RETURN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+      group.push_back(std::move(module));
+    }
+    return std::move(group);
+  }
+};
+
+// A module pass which renames instructions named 'foo' to 'bar'.
+class FooToBarModulePass : public HloModulePass {
+  absl::string_view name() const override { return "foo2bar"; }
+
+  StatusOr<bool> Run(HloModule* module) override {
+    bool changed = false;
+    for (HloComputation* computation : module->computations()) {
+      for (HloInstruction* instruction : computation->instructions()) {
+        if (instruction->name() == "foo") {
+          instruction->SetAndSanitizeName("bar");
+          changed = true;
+        }
+      }
+    }
+    return changed;
+  }
+};
+
+// A module group pass which renames instructions named 'baz' to 'qux'.
+class BazToQuxModuleGroupPass : public HloModuleGroupPass {
+  absl::string_view name() const override { return "baz2qux"; }
+
+  StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) override {
+    bool changed = false;
+    for (HloModule* module : module_group->modules()) {
+      for (HloComputation* computation : module->computations()) {
+        for (HloInstruction* instruction : computation->instructions()) {
+          if (instruction->name() == "baz") {
+            instruction->SetAndSanitizeName("qux");
+            changed = true;
+          }
+        }
+      }
+    }
+    return changed;
+  }
+};
+
+// An invariant checker pass which returns an error if there exists an
+// instruction named 'bar'.
+class BarBlowerUpper : public HloModulePass {
+  absl::string_view name() const override { return "bar-blower-upper"; }
+
+  StatusOr<bool> Run(HloModule* module) override {
+    for (HloComputation* computation : module->computations()) {
+      for (HloInstruction* instruction : computation->instructions()) {
+        if (instruction->name() == "bar") {
+          return InternalError("Module has instruction named bar");
+        }
+      }
+    }
+    return false;
+  }
+};
+
+TEST_F(HloPassPipelineTest, ModulePassChanged) {
+  // Test an HLO module pass which changes a module.
+  const string module_str = R"(
+HloModule ModulePassChanged
+
+ENTRY main {
+  a = f32[] parameter(0)
+  b = f32[] parameter(1)
+  ROOT foo = f32[] multiply(a, b)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(module_str));
+  HloPassPipeline pipeline(TestName());
+  pipeline.AddPass<FooToBarModulePass>();
+
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_EQ(root->name(), "foo");
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
+  EXPECT_TRUE(changed);
+  EXPECT_EQ(root->name(), "bar");
+}
+
+TEST_F(HloPassPipelineTest, ModulePassUnchanged) {
+  // Test an HLO module pass which does not change a module.
+  const string module_str = R"(
+HloModule ModulePassUnchanged
+
+ENTRY main {
+  a = f32[] parameter(0)
+  b = f32[] parameter(1)
+  ROOT blahblah = f32[] multiply(a, b)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(module_str));
+  HloPassPipeline pipeline(TestName());
+  pipeline.AddPass<FooToBarModulePass>();
+
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
+  EXPECT_FALSE(changed);
+}
+
+TEST_F(HloPassPipelineTest, MixedPipeline) {
+  // Test a pipeline with both a module pass and a module group pass.
+  const string module_0_str = R"(
+HloModule MixedPipeline.1
+
+ENTRY main {
+  a = f32[] parameter(0)
+  b = f32[] parameter(1)
+  ROOT baz = f32[] multiply(a, b)
+}
+)";
+  const string module_1_str = R"(
+HloModule MixedPipeline.0
+
+ENTRY main {
+  a = f32[] parameter(0)
+  b = f32[] parameter(1)
+  ROOT foo = f32[] multiply(a, b)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup module_group,
+                          ParseModuleGroup({module_0_str, module_1_str}));
+
+  HloPassPipeline pipeline(TestName());
+  pipeline.AddPass<BazToQuxModuleGroupPass>();
+  pipeline.AddPass<FooToBarModulePass>();
+
+  HloInstruction* root0 =
+      module_group.module(0).entry_computation()->root_instruction();
+  HloInstruction* root1 =
+      module_group.module(1).entry_computation()->root_instruction();
+  EXPECT_EQ(root0->name(), "baz");
+  EXPECT_EQ(root1->name(), "foo");
+
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          pipeline.RunOnModuleGroup(&module_group));
+  EXPECT_TRUE(changed);
+
+  EXPECT_EQ(root0->name(), "qux");
+  EXPECT_EQ(root1->name(), "bar");
+}
+
+TEST_F(HloPassPipelineTest, InvariantChecker) {
+  const string module_str = R"(
+HloModule InvariantChecker
+
+ENTRY main {
+  a = f32[] parameter(0)
+  b = f32[] parameter(1)
+  ROOT foo = f32[] multiply(a, b)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(module_str));
+  {
+    // Run a pipeline with just the invariant checker. It should not fail
+    // because there is no 'bar' instruction in the module.
+    HloPassPipeline pipeline(TestName());
+    pipeline.AddInvariantChecker<BarBlowerUpper>();
+
+    TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
+    EXPECT_FALSE(changed);
+  }
+
+  {
+    // Run a pipeline which renames 'foo' to 'bar' then an invariant checker
+    // which fails if there is an instruction named 'bar'.
+    HloPassPipeline pipeline(TestName());
+    pipeline.AddInvariantChecker<BarBlowerUpper>();
+    pipeline.AddPass<FooToBarModulePass>();
+
+    Status status = pipeline.Run(module.get()).status();
+    ASSERT_IS_NOT_OK(status);
+    EXPECT_THAT(status.error_message(),
+                ::testing::HasSubstr("Module has instruction named bar"));
+    EXPECT_THAT(status.error_message(),
+                ::testing::HasSubstr("Failed after foo2bar"));
+  }
+
+  {
+    // Run the invariant-checker only pipeline again. It should fail this time.
+    HloPassPipeline pipeline(TestName());
+    pipeline.AddInvariantChecker<BarBlowerUpper>();
+
+    Status status = pipeline.Run(module.get()).status();
+    ASSERT_IS_NOT_OK(status);
+    EXPECT_THAT(status.error_message(),
+                ::testing::HasSubstr("Module has instruction named bar"));
+    EXPECT_THAT(status.error_message(),
+                ::testing::HasSubstr("Failed after pipeline-start"));
+  }
+}
+
+TEST_F(HloPassPipelineTest, ModuleGroupPassOnModule) {
+  // Running a module group pass on a module should produce an error.
+  const string module_str = R"(
+HloModule ModuleGroupPassOnModule
+
+ENTRY main {
+  a = f32[] parameter(0)
+  b = f32[] parameter(1)
+  ROOT foo = f32[] multiply(a, b)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(module_str));
+  HloPassPipeline pipeline(TestName());
+  pipeline.AddPass<BazToQuxModuleGroupPass>();
+
+  Status status = pipeline.Run(module.get()).status();
+  ASSERT_IS_NOT_OK(status);
+  EXPECT_THAT(
+      status.error_message(),
+      ::testing::HasSubstr("Module group pass cannot be run on a module"));
+}
+
+}  // namespace
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index bd6dd79..a438671 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -1198,6 +1198,12 @@
           << HumanReadableNumBytes(memory_limit_bytes_);
   XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString());
 
+  // Initialize pass object state.
+  computation_peak_memory_.clear();
+  rematerialized_computations_.clear();
+  instructions_rematerialized_ = 0;
+  net_instructions_added_ = 0;
+
   TF_RET_CHECK(module->has_schedule());
   TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module));
 
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h
index e2aaf18..7330d73 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.h
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h
@@ -33,7 +33,7 @@
 // CSE will undo the effects of this optimization and should not be run after
 // this pass. In general, this pass should be run very late, immediately before
 // code generation.
-class HloRematerialization : public HloPassInterface {
+class HloRematerialization : public HloModulePass {
  public:
   using ShapeSizeFunction = std::function<int64(const Shape&)>;
 
diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h
index d1cf644..fa34bdd 100644
--- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h
+++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h
@@ -22,7 +22,7 @@
 
 // Unify subcomputations of a `HloModule`: if any computations are equal, choose
 // one arbitrarily to use and delete the others.
-class HloSubcomputationUnification : public HloPassInterface {
+class HloSubcomputationUnification : public HloModulePass {
  public:
   absl::string_view name() const override {
     return "subcomputation-unification";
diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc
index 773fc7d..8549487 100644
--- a/tensorflow/compiler/xla/service/hlo_value.cc
+++ b/tensorflow/compiler/xla/service/hlo_value.cc
@@ -131,6 +131,7 @@
       CHECK_LE(operand_number, 2);
       return operand_number == 0 || index.empty();
 
+    case HloOpcode::kDomain:
     case HloOpcode::kTuple:
       // These instructions always pass through their operands transparently.
       return false;
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 50f39cb..6eb6658 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -1057,6 +1057,7 @@
 }  // namespace
 
 StatusOr<bool> HloVerifier::Run(HloModule* module) {
+  TF_RET_CHECK(!module->name().empty());
   TF_RETURN_IF_ERROR(VerifyHloStructure(module));
   TF_RETURN_IF_ERROR(VerifySendsAndRecvs(*module));
 
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index 42e3027..0cde4a3 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -151,7 +151,7 @@
 
 // HLO pass that verifies invariants of HLO instructions for each computation in
 // the module.
-class HloVerifier : public HloPassInterface {
+class HloVerifier : public HloModulePass {
  public:
   using ShapeVerifierFactory = std::function<std::unique_ptr<ShapeVerifier>()>;
 
diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h
index 85bb4a8..9c48b7d 100644
--- a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h
+++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h
@@ -25,7 +25,7 @@
 
 // Pass which replaces all implicit broadcasts with their equivalent sequence of
 // explicit broadcast and reshape instructions.
-class ImplicitBroadcastRemover : public HloPassInterface {
+class ImplicitBroadcastRemover : public HloModulePass {
  public:
   ImplicitBroadcastRemover() {}
   ~ImplicitBroadcastRemover() override {}
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h
index df9cbab..3e238f9 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.h
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h
@@ -366,7 +366,7 @@
 // A pass that prints all non-trivial results returned by IndexedArrayAnalysis.
 // This pass is a no-op if !VLOG_IS_ON(2) so it should be fine to
 // unconditionally add to the regular HLO pass pipeline.
-class IndexedArrayAnalysisPrinterPass : public HloPassInterface {
+class IndexedArrayAnalysisPrinterPass : public HloModulePass {
  public:
   absl::string_view name() const override;
   StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/inliner.h b/tensorflow/compiler/xla/service/inliner.h
index efa8ed3..e20af08 100644
--- a/tensorflow/compiler/xla/service/inliner.h
+++ b/tensorflow/compiler/xla/service/inliner.h
@@ -24,7 +24,7 @@
 // A pass which performs inlining. Which can result, for example, in functions
 // that were previously being mapped by Map instead directly applied to the
 // forwarded operands (i.e., map({X, Y}, max) -> max(X, Y)).
-class Inliner : public HloPassInterface {
+class Inliner : public HloModulePass {
  public:
   ~Inliner() override = default;
   absl::string_view name() const override { return "inline"; }
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h
index c1fde8e..7e1196f 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.h
+++ b/tensorflow/compiler/xla/service/instruction_fusion.h
@@ -56,7 +56,7 @@
 // with the intent that the loops which compute their values will be fused in
 // code generation. Derived classes define ShouldFuse method to select which
 // instructions to fuse.
-class InstructionFusion : public HloPassInterface {
+class InstructionFusion : public HloModulePass {
  public:
   explicit InstructionFusion(
       std::function<bool(const HloInstruction& instruction)> is_expensive,
diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h
index cf54503..e29c199 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.h
+++ b/tensorflow/compiler/xla/service/layout_assignment.h
@@ -281,7 +281,7 @@
 
 // HLO pass which assigns layouts to all instructions in the HLO module while
 // satisfying all necessary invariants and minimizing cost.
-class LayoutAssignment : public HloPassInterface {
+class LayoutAssignment : public HloModulePass {
  public:
   // entry_computation_layout is modified to populate a layout for the result in
   // the case that no particular layout is requested.
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h
index d2c5265..0344626 100644
--- a/tensorflow/compiler/xla/service/multi_output_fusion.h
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.h
@@ -44,7 +44,7 @@
 //  Note that the reachability map is updated based on the original computation.
 //  This works because the reachability is monotonically increasing with
 //  instruction fusion.
-class MultiOutputFusion : public HloPassInterface {
+class MultiOutputFusion : public HloModulePass {
  public:
   MultiOutputFusion(int64 fuel) : fuel_(fuel) {}
 
diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc
index bd8fb17..ac2f796 100644
--- a/tensorflow/compiler/xla/service/name_uniquer.cc
+++ b/tensorflow/compiler/xla/service/name_uniquer.cc
@@ -39,8 +39,10 @@
 }
 
 /*static*/ string NameUniquer::GetSanitizedName(const string& name) {
+  if (name.empty()) {
+    return "";
+  }
   string result = name;
-  CHECK(!result.empty()) << "name should not be empty";
   char c = static_cast<unsigned char>(result[0]);
   if (!isalpha(c) && c != '_') {
     result[0] = '_';
diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h
index 4869db7..7d4d62e 100644
--- a/tensorflow/compiler/xla/service/pattern_matcher.h
+++ b/tensorflow/compiler/xla/service/pattern_matcher.h
@@ -17,8 +17,12 @@
 #define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
 
 #include "absl/strings/string_view.h"
+#include "absl/utility/utility.h"
 #include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 
@@ -228,8 +232,46 @@
   LayoutType** matched_layout_;
 };
 
+template <typename Item, typename... Patterns>
+class AnyOfPattern {
+ public:
+  explicit AnyOfPattern(const Patterns&... patterns) : patterns_(patterns...) {}
+
+  bool Match(const Item* item) const {
+    return MatchImpl(item, std::integral_constant<size_t, 0>());
+  }
+
+  bool Match(Item* item) const {
+    return MatchImpl(item, std::integral_constant<size_t, 0>());
+  }
+
+ private:
+  template <typename ItemType, size_t index>
+  bool MatchImpl(ItemType* item, std::integral_constant<size_t, index>) const {
+    return std::get<index>(patterns_).Match(item) ||
+           MatchImpl(item, std::integral_constant<size_t, index + 1>());
+  }
+
+  template <typename ItemType>
+  bool MatchImpl(ItemType* item,
+                 std::integral_constant<size_t, sizeof...(Patterns)>) const {
+    return false;
+  }
+
+  std::tuple<Patterns...> patterns_;
+};
 }  // namespace detail
 
+// Returns a pattern that represents the logical disjunction of the input
+// patterns. The returned pattern matches from left to right, and stops on the
+// first match.
+template <typename Item, typename... Patterns>
+detail::AnyOfPattern<typename std::remove_const<Item>::type, Patterns...> AnyOf(
+    const Patterns&... patterns) {
+  return detail::AnyOfPattern<typename std::remove_const<Item>::type,
+                              Patterns...>(patterns...);
+}
+
 // Creates a layout pattern that will capture the matched layout in the
 // argument.
 inline constexpr detail::LayoutPattern<const ::xla::Layout,
@@ -752,6 +794,27 @@
   int64 tuple_index_;
 };
 
+template <typename Previous, typename ItemType, typename Predicate>
+class HloPredicatePatternImpl {
+ public:
+  explicit HloPredicatePatternImpl(const Previous& previous, Predicate pred)
+      : previous_(previous), pred_(std::move(pred)) {}
+
+  bool Match(const ItemType* item) const {
+    return previous_.Match(item) && pred_(item);
+  }
+
+  bool Match(ItemType* item) const {
+    return previous_.Match(item) && pred_(item);
+  }
+
+ private:
+  Previous previous_;
+  Predicate pred_;
+};
+
+struct PatternFriend;
+
 // A pattern that matches HloInstructions.
 template <typename HloInstructionType, typename Impl>
 class HloInstructionPattern {
@@ -879,6 +942,21 @@
   }
 
  private:
+  template <typename Predicate>
+  constexpr HloInstructionPattern<
+      HloInstructionType,
+      HloPredicatePatternImpl<
+          Impl, typename std::remove_const<HloInstructionType>::type,
+          Predicate>>
+  WithPredicate(Predicate pred) const {
+    using NewImplType = HloPredicatePatternImpl<
+        Impl, typename std::remove_const<HloInstructionType>::type, Predicate>;
+    return HloInstructionPattern<HloInstructionType, NewImplType>(
+        NewImplType(impl_, std::move(pred)), matched_inst_);
+  }
+
+  friend struct PatternFriend;
+
   Impl impl_;
   HloInstructionType** matched_inst_;
 };
@@ -1005,31 +1083,50 @@
         .WithOperand(0, std::forward<Lhs>(lhs))                             \
         .WithOperand(1, std::forward<Rhs>(rhs));                            \
   }
-XLA_BINOP_PATTERN(Add)
+
+#define XLA_COMMUTATIVE_BINOP_PATTERN(NAME)                                 \
+  XLA_BINOP_PATTERN(NAME)                                                   \
+                                                                            \
+  template <typename Lhs, typename Rhs>                                     \
+  inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs)                          \
+      ->decltype(AnyOf<HloInstruction>(NAME(lhs, rhs), NAME(rhs, lhs))) {   \
+    return AnyOf<HloInstruction>(NAME(lhs, rhs), NAME(rhs, lhs));           \
+  }                                                                         \
+                                                                            \
+  template <typename HloInstructionType, typename Lhs, typename Rhs>        \
+  inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs,  \
+                             Rhs&& rhs)                                     \
+      ->decltype(AnyOf<HloInstructionType>(NAME(matched_inst, lhs, rhs),    \
+                                           NAME(matched_inst, rhs, lhs))) { \
+    return AnyOf<HloInstructionType>(NAME(matched_inst, lhs, rhs),          \
+                                     NAME(matched_inst, rhs, lhs));         \
+  }
+XLA_COMMUTATIVE_BINOP_PATTERN(Add)
 XLA_BINOP_PATTERN(Atan2)
 XLA_BINOP_PATTERN(Divide)
 XLA_BINOP_PATTERN(Complex)
 XLA_BINOP_PATTERN(Dot)
-XLA_BINOP_PATTERN(Eq)
+XLA_COMMUTATIVE_BINOP_PATTERN(Eq)
 XLA_BINOP_PATTERN(Gather)
 XLA_BINOP_PATTERN(Ge)
 XLA_BINOP_PATTERN(Gt)
 XLA_BINOP_PATTERN(Le)
 XLA_BINOP_PATTERN(Lt)
-XLA_BINOP_PATTERN(Maximum)
-XLA_BINOP_PATTERN(Minimum)
-XLA_BINOP_PATTERN(Multiply)
-XLA_BINOP_PATTERN(Ne)
+XLA_COMMUTATIVE_BINOP_PATTERN(Maximum)
+XLA_COMMUTATIVE_BINOP_PATTERN(Minimum)
+XLA_COMMUTATIVE_BINOP_PATTERN(Multiply)
+XLA_COMMUTATIVE_BINOP_PATTERN(Ne)
 XLA_BINOP_PATTERN(Outfeed)
 XLA_BINOP_PATTERN(Power)
 XLA_BINOP_PATTERN(Remainder)
 XLA_BINOP_PATTERN(Send)
 XLA_BINOP_PATTERN(Subtract)
-XLA_BINOP_PATTERN(And)
-XLA_BINOP_PATTERN(Or)
+XLA_COMMUTATIVE_BINOP_PATTERN(And)
+XLA_COMMUTATIVE_BINOP_PATTERN(Or)
 XLA_BINOP_PATTERN(ShiftLeft)
 XLA_BINOP_PATTERN(ShiftRightArithmetic)
 XLA_BINOP_PATTERN(ShiftRightLogical)
+#undef XLA_COMMUTATIVE_BINOP_PATTERN
 #undef XLA_BINOP_PATTERN
 
 // Helpers for ternary instructions.
@@ -1070,6 +1167,30 @@
 XLA_TERNOP_PATTERN(Select);
 #undef XLA_TERNOP_PATTERN
 
+namespace detail {
+struct PatternFriend {
+  template <typename T>
+  static auto ConstantScalar(T constant) -> decltype(
+      Constant()
+          .WithShape(match::Shape().IsScalar())
+          .WithPredicate(
+              std::declval<std::function<bool(const HloInstruction*)>>())) {
+    std::function<bool(const HloInstruction*)> pred =
+        [constant](const HloInstruction* instr) {
+          const auto& literal = Cast<HloConstantInstruction>(instr)->literal();
+          auto status_or_const = LiteralUtil::CreateR0(constant).Convert(
+              literal.shape().element_type());
+          return status_or_const.ok() &&
+                 literal == status_or_const.ConsumeValueOrDie();
+        };
+
+    return Constant()
+        .WithShape(match::Shape().IsScalar())
+        .WithPredicate(std::move(pred));
+  }
+};
+}  // namespace detail
+
 // Helpers for matching non-constant instructions.
 inline auto NonConstant() -> decltype(Op().IsNonConstant()) {
   return Op().IsNonConstant();
@@ -1107,6 +1228,12 @@
       .WithTupleIndex(tuple_index);
 }
 
+template <typename T>
+inline auto ConstantScalar(T constant)
+    -> decltype(detail::PatternFriend::ConstantScalar(constant)) {
+  return detail::PatternFriend::ConstantScalar(constant);
+}
+
 }  // namespace match
 
 }  // namespace xla
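A brief sketch of how the new helpers compose (here 'instr' and 'operand' are placeholders, not part of this patch): matching an add of a scalar constant 1 in either operand order, while capturing the other operand.

    HloInstruction* operand = nullptr;
    bool is_add_of_one = Match(
        instr, match::AddAnyOrder(match::Op(&operand), match::ConstantScalar(1)));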
diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc
index a530581..b3a2c95 100644
--- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc
+++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc
@@ -211,5 +211,89 @@
   EXPECT_TRUE(Match(root, match::GetTupleElement(match::Op(), 1)));
 }
 
+TEST(PatternMatcherTest, AnyOf) {
+  constexpr char kModuleStr[] = R"(
+    HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })";
+  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+  auto* root = hlo_module->entry_computation()->root_instruction();
+
+  EXPECT_TRUE(
+      Match(root, match::AnyOf<HloInstruction>(match::ConstantScalar(0),
+                                               match::ConstantScalar(1))));
+  EXPECT_TRUE(
+      Match(root, match::AnyOf<HloInstruction>(match::ConstantScalar(1),
+                                               match::ConstantScalar(0))));
+  EXPECT_FALSE(
+      Match(root, match::AnyOf<HloInstruction>(match::ConstantScalar(0),
+                                               match::ConstantScalar(2))));
+}
+
+TEST(PatternMatcherTest, ConstantScalar) {
+  constexpr char kModuleStr[] = R"(
+    HloModule test_module ENTRY test { ROOT constant = f16[] constant(42) })";
+  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+  auto* root = hlo_module->entry_computation()->root_instruction();
+
+  EXPECT_TRUE(Match(root, match::ConstantScalar(42)));
+  EXPECT_FALSE(Match(root, match::ConstantScalar(41)));
+  EXPECT_FALSE(Match(root, match::ConstantScalar(0)));
+}
+
+TEST(PatternMatcherTest, MultiplyAnyOrder) {
+  using match::ConstantScalar;
+  using match::MultiplyAnyOrder;
+
+  constexpr char kModuleStr[] = R"(
+    HloModule test_module
+    ENTRY test {
+      lhs = f16[] constant(42)
+      rhs = f16[] constant(52)
+      ROOT multiply = f16[] multiply(lhs, rhs)
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+  auto* root = hlo_module->entry_computation()->root_instruction();
+  const HloInstruction* instr;
+
+  EXPECT_TRUE(Match(
+      root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52))));
+  EXPECT_TRUE(Match(
+      root, MultiplyAnyOrder(&instr, ConstantScalar(52), ConstantScalar(42))));
+}
+
+TEST(PatternMatcherTest, AnyOfShortCircuit) {
+  using match::AnyOf;
+  using match::Multiply;
+  using match::Op;
+
+  constexpr char kModuleStr[] = R"(
+    HloModule test_module
+    ENTRY test {
+      lhs = f16[] constant(42)
+      rhs = f16[] constant(52)
+      ROOT multiply = f16[] multiply(lhs, rhs)
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+  auto* root = hlo_module->entry_computation()->root_instruction();
+
+  {
+    const HloInstruction* mul = nullptr;
+    const HloInstruction* any = nullptr;
+
+    ASSERT_TRUE(Match(
+        root, AnyOf<HloInstruction>(Multiply(&mul, Op(), Op()), Op(&any))));
+    EXPECT_NE(nullptr, mul);
+    EXPECT_EQ(nullptr, any);
+  }
+  {
+    const HloInstruction* mul = nullptr;
+    const HloInstruction* any = nullptr;
+
+    ASSERT_TRUE(Match(
+        root, AnyOf<HloInstruction>(Op(&any), Multiply(&mul, Op(), Op()))));
+    EXPECT_NE(nullptr, any);
+    EXPECT_EQ(nullptr, mul);
+  }
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.h b/tensorflow/compiler/xla/service/reduce_precision_insertion.h
index 256b231..4bb2242 100644
--- a/tensorflow/compiler/xla/service/reduce_precision_insertion.h
+++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.h
@@ -29,7 +29,7 @@
 // HLO pass which inserts reduce-precision instructions into the HLO graph, for
 // purposes of experimenting with the effects of reduced-precision storage of
 // intermediate values.
-class ReducePrecisionInsertion : public HloPassInterface {
+class ReducePrecisionInsertion : public HloModulePass {
   using InstructionFilterFunction = std::function<bool(const HloInstruction*)>;
 
  public:
diff --git a/tensorflow/compiler/xla/service/reshape_mover.h b/tensorflow/compiler/xla/service/reshape_mover.h
index 1e86a08..a3db439 100644
--- a/tensorflow/compiler/xla/service/reshape_mover.h
+++ b/tensorflow/compiler/xla/service/reshape_mover.h
@@ -24,7 +24,7 @@
 // This now only moves them outputward across elementwise ops all whose operands
 // are equivalent Reshapes or Transposes, but in future could potentially move
 // them inputward also.
-class ReshapeMover : public HloPassInterface {
+class ReshapeMover : public HloModulePass {
  public:
   absl::string_view name() const override { return "reshape-mover"; }
 
diff --git a/tensorflow/compiler/xla/service/scatter_expander.h b/tensorflow/compiler/xla/service/scatter_expander.h
index 14f062c..559a85d 100644
--- a/tensorflow/compiler/xla/service/scatter_expander.h
+++ b/tensorflow/compiler/xla/service/scatter_expander.h
@@ -20,7 +20,7 @@
 
 namespace xla {
 
-class ScatterExpander : public HloPassInterface {
+class ScatterExpander : public HloModulePass {
  public:
   absl::string_view name() const override { return "scatter_expander"; }
   StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 922ebdf..b27a92f 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -812,7 +812,7 @@
   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
                       HloModule::CreateFromProto(module_proto, *module_config));
 
-  TF_RETURN_IF_ERROR(MaybeDumpHloModule(*module));
+  TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*module));
 
   TF_ASSIGN_OR_RETURN(
       module, backend->compiler()->RunHloPasses(std::move(module), executor,
@@ -1160,7 +1160,7 @@
   return replicas;
 }
 
-Status Service::MaybeDumpHloModule(const HloModule& module) const {
+Status Service::MaybeDumpUnoptimizedHloModule(const HloModule& module) const {
   const string xla_dump_unoptimized_hlo_proto_to =
       module.config().debug_options().xla_dump_unoptimized_hlo_proto_to();
   if (xla_dump_unoptimized_hlo_proto_to.empty()) {
@@ -1168,7 +1168,8 @@
   }
   HloProto proto = MakeHloProto(module);
   return protobuf_util::DumpProtoToDirectory(
-      proto, xla_dump_unoptimized_hlo_proto_to, module.name());
+      proto, xla_dump_unoptimized_hlo_proto_to,
+      StrCat(module.name(), ".unoptimized"));
 }
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index 44c5248..1f62fad 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -271,7 +271,9 @@
   StatusOr<std::vector<se::StreamExecutor*>> Replicas(
       const Backend& backend, const DeviceHandle& device_handle) const;
 
-  Status MaybeDumpHloModule(const HloModule& module) const;
+  // Dumps the (unoptimized) module given if the corresponding DebugOptions
+  // field has been set.
+  Status MaybeDumpUnoptimizedHloModule(const HloModule& module) const;
 
   // Returns the device handle that represents the replicated device for a
   // single computation that is not model-parallelized.
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 74bdf2a..7194b2c 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -1665,10 +1665,11 @@
   if (input_features != kernel_input_features * feature_group_count) {
     return InvalidArgument(
         "Expected LHS feature dimension (value %d) to match RHS "
-        "input feature dimension * feature_group_count (value %d); "
+        "input feature dimension * feature_group_count (value %d * %d = %d); "
         "got <conv>(%s, %s)\n"
         "Dimension numbers: {%s}.",
-        input_features, kernel_input_features * feature_group_count,
+        input_features, kernel_input_features, feature_group_count,
+        kernel_input_features * feature_group_count,
         ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
         dnums.DebugString());
   }
diff --git a/tensorflow/compiler/xla/service/stream_pool.cc b/tensorflow/compiler/xla/service/stream_pool.cc
index 5d1cd1c..ec09dff 100644
--- a/tensorflow/compiler/xla/service/stream_pool.cc
+++ b/tensorflow/compiler/xla/service/stream_pool.cc
@@ -28,8 +28,14 @@
       // Re-use an existing stream from the pool.
       stream = std::move(streams_.back());
       streams_.pop_back();
-      VLOG(1) << stream->DebugStreamPointers()
-              << " StreamPool reusing existing stream";
+      if (stream->ok()) {
+        VLOG(1) << stream->DebugStreamPointers()
+                << " StreamPool reusing existing stream";
+      } else {
+        VLOG(1) << stream->DebugStreamPointers()
+                << " stream was not ok, StreamPool deleting";
+        stream = nullptr;
+      }
     }
   }
 
diff --git a/tensorflow/compiler/xla/service/stream_pool_test.cc b/tensorflow/compiler/xla/service/stream_pool_test.cc
index aaf5c37..92f4757 100644
--- a/tensorflow/compiler/xla/service/stream_pool_test.cc
+++ b/tensorflow/compiler/xla/service/stream_pool_test.cc
@@ -132,5 +132,39 @@
   EXPECT_EQ(stream2_ptr, stream3_ptr);
 }
 
+TEST_F(StreamPoolTest, BadStreamAfterReturnDiscarded) {
+  std::unique_ptr<se::StreamExecutor> executor = NewStreamExecutor();
+  StreamPool pool;
+
+  // Borrow a stream.
+  StreamPool::Ptr stream1 = pool.BorrowStream(executor.get());
+  EXPECT_TRUE(stream1->ok());
+
+  // Return the stream, but hold a handle to it.
+  se::Stream* stream1_ptr = stream1.get();
+  stream1 = nullptr;
+
+  // Now that stream1 is back in the pool, force an error on the stream. Here
+  // we call a method that requires DNN support, which we know the Host
+  // platform doesn't support.
+  stream1_ptr->ThenDepthConcatenate({}, {}, nullptr);
+  EXPECT_FALSE(stream1_ptr->ok());
+
+  // Borrow stream2.
+  StreamPool::Ptr stream2 = pool.BorrowStream(executor.get());
+  EXPECT_TRUE(stream2->ok());
+
+  // The underlying streams should be different. They would have been
+  // the same, but since we forced an error on stream1, it cannot be
+  // put back into the pool. Sadly we can't just check:
+  //    EXPECT_NE(stream1_ptr, stream2_ptr);
+  //
+  // The above should hold logically, but it may fail if the new
+  // stream instance allocated for stream2 happens to reside in the
+  // same memory address as stream1, which has been deleted.
+  //
+  // The EXPECT_TRUE(stream2->ok()) above serves as a good-enough check.
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/transpose_folding.h b/tensorflow/compiler/xla/service/transpose_folding.h
index 3e5aa2d..f95f982 100644
--- a/tensorflow/compiler/xla/service/transpose_folding.h
+++ b/tensorflow/compiler/xla/service/transpose_folding.h
@@ -23,7 +23,7 @@
 
 // HLO pass that folds transpose operators into Dot operators, where the Dot
 // operator is implemented by a GEMM kernel that can transpose its inputs.
-class TransposeFolding : public HloPassInterface {
+class TransposeFolding : public HloModulePass {
  public:
   using OperandIndices = std::vector<int64>;
 
diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.h b/tensorflow/compiler/xla/service/tuple_simplifier.h
index 8c91d6e..e126a53 100644
--- a/tensorflow/compiler/xla/service/tuple_simplifier.h
+++ b/tensorflow/compiler/xla/service/tuple_simplifier.h
@@ -25,7 +25,7 @@
 
 // A pass which simplifies patterns of Tuple and GetTupleElement instructions in
 // the module.
-class TupleSimplifier : public HloPassInterface {
+class TupleSimplifier : public HloModulePass {
  public:
   TupleSimplifier() : TupleSimplifier(/*exclude_entry_computation=*/false) {}
   explicit TupleSimplifier(bool exclude_entry_computation);
diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h
index 2dba7d7..577bad6 100644
--- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h
+++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h
@@ -50,7 +50,7 @@
 // conditions as well.
 //
 // TODO(b/79121449):  We should also sink broadcasts of constants.
-class WhileLoopConstantSinking : public HloPassInterface {
+class WhileLoopConstantSinking : public HloModulePass {
  public:
   ~WhileLoopConstantSinking() override = default;
 
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
index 2cdf20c..3031899 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
@@ -25,7 +25,7 @@
 // HLO pass that rewrites while loops to hoist loop invariant instructions in
 // the while body into the computation that contains the while instruction.
 
-class WhileLoopInvariantCodeMotion : public HloPassInterface {
+class WhileLoopInvariantCodeMotion : public HloModulePass {
  public:
   // If `hoist_constants` is true then constants are always hoisted out of while
   // loop bodies.  Otherwise they are only hoisted out if they enable other
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
index 6a7bfe3..9a74f22 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
@@ -252,7 +252,7 @@
   // Create the new while condition, body, and init value.
   std::unique_ptr<HloComputation> new_while_cond =
       while_cond->CloneWithReplacements(
-          make_while_computation_replacements(while_cond));
+          make_while_computation_replacements(while_cond), /*extras=*/{});
 
   std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
       while_body_replacements = make_while_computation_replacements(while_body);
@@ -265,7 +265,8 @@
   while_body_replacements.emplace(
       while_body_root, HloInstruction::CreateTuple(new_while_body_root_elems));
   std::unique_ptr<HloComputation> new_while_body =
-      while_body->CloneWithReplacements(std::move(while_body_replacements));
+      while_body->CloneWithReplacements(std::move(while_body_replacements),
+                                        /*extras=*/{});
 
   // Add a new while_init instruction that repackages the old while_init
   // instruction's elements.  We rely on the AlgebraicSimplifier and DCE to
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.h b/tensorflow/compiler/xla/service/while_loop_simplifier.h
index 78024f1..0bc5a01 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier.h
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier.h
@@ -30,7 +30,7 @@
 //  - Elements of a while loop's tuple that the loop doesn't use are removed
 //    from the tuple.
 //
-class WhileLoopSimplifier : public HloPassInterface {
+class WhileLoopSimplifier : public HloModulePass {
  public:
   ~WhileLoopSimplifier() override {}
   absl::string_view name() const override { return "simplify-while-loops"; }
diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h
index a7f0e20..8729412 100644
--- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h
+++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h
@@ -21,7 +21,7 @@
 
 // HLO pass that replaces zero sized Hlos with a zero sized constant literal.
 namespace xla {
-class ZeroSizedHloElimination : public HloPassInterface {
+class ZeroSizedHloElimination : public HloModulePass {
  public:
   StatusOr<bool> Run(HloModule* module) override;
   absl::string_view name() const override {
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 623ae39..d8bb27b 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -22,6 +22,7 @@
 #include <initializer_list>
 #include <string>
 
+#include "absl/base/macros.h"
 #include "absl/container/inlined_vector.h"
 #include "absl/types/optional.h"
 #include "absl/types/span.h"
@@ -479,8 +480,7 @@
 
   // Shorthand for testing whether a shape is of a given element type and
   // sequence of dimensions.
-  //
-  // DEPRECATED: Use Equal() instead.
+  ABSL_DEPRECATED("Use Equal() instead.")
   static bool ShapeIs(const Shape& shape, PrimitiveType element_type,
                       std::initializer_list<int64> dimensions);
 
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 30e3077..fd3e3bf 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -29,6 +29,10 @@
 load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_test_macros")
 load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
 load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load(
+    "//tensorflow/core:platform/default/build_config_root.bzl",
+    "tf_cuda_tests_tags",
+)
 
 # Generate test_suites for all backends, named "${backend}_tests".
 generate_backend_suites()
@@ -150,11 +154,31 @@
         "//tensorflow/compiler/xla/service:hlo_parser",
         "//tensorflow/compiler/xla/service:hlo_verifier",
         "//tensorflow/core:lib",
-        "//tensorflow/core:test",
+        "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/memory",
     ],
 )
 
+tf_cc_test(
+    name = "hlo_verified_test_base_test",
+    srcs = ["hlo_verified_test_base_test.cc"],
+    deps = [
+        ":hlo_test_base",
+        ":hlo_verified_test_base",
+        ":test_macros_cpu",
+        ":test_utils",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla/client:xla_builder",
+        "//tensorflow/compiler/xla/client:xla_computation",
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_parser",
+        "//tensorflow/compiler/xla/service:hlo_verifier",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+    ],
+)
+
 tf_cc_binary(
     name = "local_client_aot_test_helper",
     srcs = ["local_client_aot_test_helper.cc"],
@@ -1797,7 +1821,7 @@
 tf_cc_test(
     name = "llvm_compiler_test",
     srcs = ["llvm_compiler_test.cc"],
-    tags = ["requires-gpu-sm35"],
+    tags = tf_cuda_tests_tags(),
     deps = [
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:test_helpers",
@@ -2096,7 +2120,7 @@
     name = "sample_file_test",
     srcs = ["sample_file_test.cc"],
     data = ["isolated_convolution.hlo"],
-    tags = ["requires-gpu-sm35"],
+    tags = tf_cuda_tests_tags(),
     deps = [
         ":hlo_test_base",
         "//tensorflow/compiler/xla:test",
diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl
index 53f2c3b..cc65a89 100644
--- a/tensorflow/compiler/xla/tests/build_defs.bzl
+++ b/tensorflow/compiler/xla/tests/build_defs.bzl
@@ -3,256 +3,266 @@
 load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured")
 load("//tensorflow/compiler/xla/tests:plugin.bzl", "plugins")
 load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load(
+    "//tensorflow/core:platform/default/build_config_root.bzl",
+    "tf_cuda_tests_tags",
+)
 
 all_backends = ["cpu", "gpu"] + plugins.keys()
 
 def filter_backends(backends):
-  """Removes "gpu" from a backend list if CUDA is not enabled.
+    """Removes "gpu" from a backend list if CUDA is not enabled.
 
-  This allows us to simply hardcode lists including "gpu" here and in the
-  BUILD file, without causing failures when CUDA isn't enabled.'
+    This allows us to simply hardcode lists including "gpu" here and in the
+    BUILD file, without causing failures when CUDA isn't enabled.
 
-  Args:
-    backends: A list of backends to filter.
+    Args:
+      backends: A list of backends to filter.
 
-  Returns:
-    The filtered list of backends.
-  """
-  if cuda_is_configured():
-    return backends
-  else:
-    return [backend for backend in backends if backend != "gpu"]
-
-
-def xla_test(name,
-             srcs,
-             deps,
-             xla_test_library_deps=[],
-             backends=[],
-             blacklisted_backends=[],
-             args=[],
-             tags=[],
-             copts=[],
-             data=[],
-             backend_tags={},
-             backend_args={},
-             **kwargs):
-  """Generates cc_test targets for the given XLA backends.
-
-  This rule generates a cc_test target for one or more XLA backends and also a
-  platform-agnostic cc_library rule. The arguments are identical to cc_test with
-  two additions: 'backends' and 'backend_args'. 'backends' specifies the
-  backends to generate tests for ("cpu", "gpu"), and
-  'backend_args'/'backend_tags' specifies backend-specific args parameters to
-  use when generating the cc_test.
-
-  The name of the cc_tests are the provided name argument with the backend name
-  appended, and the cc_library target name is the provided name argument with
-  "_lib" appended. For example, if name parameter is "foo_test", then the cpu
-  test target will be "foo_test_cpu" and the cc_library target is "foo_lib".
-
-  The cc_library target can be used to link with other plugins outside of
-  xla_test.
-
-  The build rule also defines a test suite ${name} which includes the tests for
-  each of the supported backends.
-
-  Each generated cc_test target has a tag indicating which backend the test is
-  for. This tag is of the form "xla_${BACKEND}" (eg, "xla_cpu"). These
-  tags can be used to gather tests for a particular backend into a test_suite.
-
-  Examples:
-
-    # Generates the targets: foo_test_cpu and foo_test_gpu.
-    xla_test(
-        name = "foo_test",
-        srcs = ["foo_test.cc"],
-        backends = ["cpu", "gpu"],
-        deps = [...],
-    )
-
-    # Generates the targets: bar_test_cpu and bar_test_gpu. bar_test_cpu
-    # includes the additional arg "--special_cpu_flag".
-    xla_test(
-        name = "bar_test",
-        srcs = ["bar_test.cc"],
-        backends = ["cpu", "gpu"],
-        backend_args = {"cpu": ["--special_cpu_flag"]}
-        deps = [...],
-    )
-
-  The build rule defines the preprocessor macro XLA_TEST_BACKEND_${BACKEND}
-  to the value 1 where ${BACKEND} is the uppercase name of the backend.
-
-  Args:
-    name: Name of the target.
-    srcs: Sources for the target.
-    deps: Dependencies of the target.
-    xla_test_library_deps: If set, the generated test targets will depend on the
-      respective cc_libraries generated by the xla_test_library rule.
-    backends: A list of backends to generate tests for. Supported values: "cpu",
-      "gpu". If this list is empty, the test will be generated for all supported
-      backends.
-    blacklisted_backends: A list of backends to NOT generate tests for.
-    args: Test arguments for the target.
-    tags: Tags for the target.
-    copts: Additional copts to pass to the build.
-    data: Additional data to pass to the build.
-    backend_tags: A dict mapping backend name to list of additional tags to
-      use for that target.
-    backend_args: A dict mapping backend name to list of additional args to
-      use for that target.
-    **kwargs: Additional keyword arguments to pass to native.cc_test.
-  """
-  test_names = []
-  if not backends:
-    backends = all_backends
-
-  backends = [backend for backend in backends
-              if backend not in blacklisted_backends]
-
-  native.cc_library(
-      name="%s_lib" % name,
-      srcs=srcs,
-      copts=copts,
-      testonly=True,
-      deps=deps + ["//tensorflow/compiler/xla/tests:test_macros_header"],
-  )
-
-  for backend in filter_backends(backends):
-    test_name = "%s_%s" % (name, backend)
-    this_backend_tags = ["xla_%s" % backend]
-    this_backend_copts = []
-    this_backend_args = backend_args.get(backend, [])
-    this_backend_data = []
-    if backend == "cpu":
-      backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"]
-      backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"]
-    elif backend == "gpu":
-      backend_deps = ["//tensorflow/compiler/xla/service:gpu_plugin"]
-      backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_gpu"]
-      this_backend_tags += ["requires-gpu-sm35"]
-    elif backend in plugins:
-      backend_deps = []
-      backend_deps += plugins[backend]["deps"]
-      this_backend_copts += plugins[backend]["copts"]
-      this_backend_tags += plugins[backend]["tags"]
-      this_backend_args += plugins[backend]["args"]
-      this_backend_data += plugins[backend]["data"]
+    Returns:
+      The filtered list of backends.
+    """
+    if cuda_is_configured():
+        return backends
     else:
-      fail("Unknown backend %s" % backend)
+        return [backend for backend in backends if backend != "gpu"]
 
-    if xla_test_library_deps:
-      for lib_dep in xla_test_library_deps:
-        backend_deps += ["%s_%s" % (lib_dep, backend)]
+def xla_test(
+        name,
+        srcs,
+        deps,
+        xla_test_library_deps = [],
+        backends = [],
+        blacklisted_backends = [],
+        args = [],
+        tags = [],
+        copts = [],
+        data = [],
+        backend_tags = {},
+        backend_args = {},
+        **kwargs):
+    """Generates cc_test targets for the given XLA backends.
 
-    tf_cc_test(
-        name=test_name,
-        srcs=srcs,
-        tags=tags + backend_tags.get(backend, []) + this_backend_tags,
-        extra_copts=copts + ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] +
-        this_backend_copts,
-        args=args + this_backend_args,
-        deps=deps + backend_deps,
-        data=data + this_backend_data,
-        **kwargs)
+    This rule generates a cc_test target for one or more XLA backends and also a
+    platform-agnostic cc_library rule. The arguments are identical to cc_test with
+    two additions: 'backends' and 'backend_args'. 'backends' specifies the
+    backends to generate tests for ("cpu", "gpu"), and
+    'backend_args'/'backend_tags' specify backend-specific args and tags to
+    use when generating the cc_test.
 
-    test_names.append(test_name)
+    The names of the cc_tests are the provided name argument with the backend name
+    appended, and the cc_library target name is the provided name argument with
+    "_lib" appended. For example, if the name parameter is "foo_test", then the cpu
+    test target will be "foo_test_cpu" and the cc_library target is "foo_test_lib".
 
-  native.test_suite(name=name, tests=test_names)
+    The cc_library target can be used to link with other plugins outside of
+    xla_test.
 
-def xla_test_library(name,
-                     srcs,
-                     hdrs=[],
-                     deps=[],
-                     backends=[]):
-  """Generates cc_library targets for the given XLA backends.
+    The build rule also defines a test suite ${name} which includes the tests for
+    each of the supported backends.
 
-  This rule forces the sources to be compiled for each backend so that the
-  backend specific macros could expand correctly. It's useful when test targets
-  in different directories referring to the same sources but test with different
-  arguments.
+    Each generated cc_test target has a tag indicating which backend the test is
+    for. This tag is of the form "xla_${BACKEND}" (e.g., "xla_cpu"). These
+    tags can be used to gather tests for a particular backend into a test_suite.
 
-  Examples:
+    Examples:
 
-    # Generates the targets: foo_test_library_cpu and foo_test_gpu.
-    xla_test_library(
-        name = "foo_test_library",
-        srcs = ["foo_test.cc"],
-        backends = ["cpu", "gpu"],
-        deps = [...],
-    )
-    # Then use the xla_test rule to generate test targets:
-    xla_test(
-        name = "foo_test",
-        srcs = [],
-        backends = ["cpu", "gpu"],
-        deps = [...],
-        xla_test_library_deps = [":foo_test_library"],
-    )
+      # Generates the targets: foo_test_cpu and foo_test_gpu.
+      xla_test(
+          name = "foo_test",
+          srcs = ["foo_test.cc"],
+          backends = ["cpu", "gpu"],
+          deps = [...],
+      )
 
-  Args:
-    name: Name of the target.
-    srcs: Sources for the target.
-    hdrs: Headers for the target.
-    deps: Dependencies of the target.
-    backends: A list of backends to generate libraries for.
-      Supported values: "cpu", "gpu". If this list is empty, the
-      library will be generated for all supported backends.
-  """
+      # Generates the targets: bar_test_cpu and bar_test_gpu. bar_test_cpu
+      # includes the additional arg "--special_cpu_flag".
+      xla_test(
+          name = "bar_test",
+          srcs = ["bar_test.cc"],
+          backends = ["cpu", "gpu"],
+          backend_args = {"cpu": ["--special_cpu_flag"]},
+          deps = [...],
+      )
 
-  if not backends:
-    backends = all_backends
+    The build rule defines the preprocessor macro XLA_TEST_BACKEND_${BACKEND}
+    to the value 1 where ${BACKEND} is the uppercase name of the backend.
 
-  for backend in filter_backends(backends):
-    this_backend_copts = []
-    if backend in ["cpu", "gpu"]:
-      backend_deps = ["//tensorflow/compiler/xla/tests:test_macros_%s" % backend]
-    elif backend in plugins:
-      backend_deps = plugins[backend]["deps"]
-      this_backend_copts += plugins[backend]["copts"]
-    else:
-      fail("Unknown backend %s" % backend)
+    Args:
+      name: Name of the target.
+      srcs: Sources for the target.
+      deps: Dependencies of the target.
+      xla_test_library_deps: If set, the generated test targets will depend on the
+        respective cc_libraries generated by the xla_test_library rule.
+      backends: A list of backends to generate tests for. Supported values: "cpu",
+        "gpu". If this list is empty, the test will be generated for all supported
+        backends.
+      blacklisted_backends: A list of backends to NOT generate tests for.
+      args: Test arguments for the target.
+      tags: Tags for the target.
+      copts: Additional copts to pass to the build.
+      data: Additional data to pass to the build.
+      backend_tags: A dict mapping backend name to list of additional tags to
+        use for that target.
+      backend_args: A dict mapping backend name to list of additional args to
+        use for that target.
+      **kwargs: Additional keyword arguments to pass to native.cc_test.
+    """
+    test_names = []
+    if not backends:
+        backends = all_backends
+
+    backends = [
+        backend
+        for backend in backends
+        if backend not in blacklisted_backends
+    ]
 
     native.cc_library(
-        name = "%s_%s" % (name, backend),
+        name = "%s_lib" % name,
         srcs = srcs,
+        copts = copts,
         testonly = True,
-        hdrs = hdrs,
-        copts = ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()]
-        + this_backend_copts,
-        deps = deps + backend_deps,
+        deps = deps + ["//tensorflow/compiler/xla/tests:test_macros_header"],
     )
 
+    for backend in filter_backends(backends):
+        test_name = "%s_%s" % (name, backend)
+        this_backend_tags = ["xla_%s" % backend]
+        this_backend_copts = []
+        this_backend_args = backend_args.get(backend, [])
+        this_backend_data = []
+        if backend == "cpu":
+            backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"]
+            backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"]
+        elif backend == "gpu":
+            backend_deps = ["//tensorflow/compiler/xla/service:gpu_plugin"]
+            backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_gpu"]
+            this_backend_tags += tf_cuda_tests_tags()
+        elif backend in plugins:
+            backend_deps = []
+            backend_deps += plugins[backend]["deps"]
+            this_backend_copts += plugins[backend]["copts"]
+            this_backend_tags += plugins[backend]["tags"]
+            this_backend_args += plugins[backend]["args"]
+            this_backend_data += plugins[backend]["data"]
+        else:
+            fail("Unknown backend %s" % backend)
 
-def generate_backend_suites(backends=[]):
-  if not backends:
-    backends = all_backends
-  for backend in filter_backends(backends):
-    native.test_suite(name="%s_tests" % backend,
-                      tags = ["xla_%s" % backend])
+        if xla_test_library_deps:
+            for lib_dep in xla_test_library_deps:
+                backend_deps += ["%s_%s" % (lib_dep, backend)]
 
+        tf_cc_test(
+            name = test_name,
+            srcs = srcs,
+            tags = tags + backend_tags.get(backend, []) + this_backend_tags,
+            extra_copts = copts + ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] +
+                          this_backend_copts,
+            args = args + this_backend_args,
+            deps = deps + backend_deps,
+            data = data + this_backend_data,
+            **kwargs
+        )
 
-def generate_backend_test_macros(backends=[]):
-  if not backends:
-    backends = all_backends
-  for backend in filter_backends(backends):
-    manifest = ""
-    if backend in plugins:
-      manifest = plugins[backend]["disabled_manifest"]
+        test_names.append(test_name)
 
-    native.cc_library(
-        name="test_macros_%s" % backend,
-        testonly = True,
-        srcs = ["test_macros.cc"],
-        hdrs = ["test_macros.h"],
-        copts = [
-          "-DXLA_PLATFORM=\\\"%s\\\"" % backend.upper(),
-          "-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest,
-        ],
-        deps = [
-            "//tensorflow/compiler/xla:types",
-            "//tensorflow/core:lib",
-            "//tensorflow/core:regexp_internal",
-            "//tensorflow/core:test",
-        ])
+    native.test_suite(name = name, tests = test_names)
+
+def xla_test_library(
+        name,
+        srcs,
+        hdrs = [],
+        deps = [],
+        backends = []):
+    """Generates cc_library targets for the given XLA backends.
+
+    This rule forces the sources to be compiled for each backend so that the
+    backend-specific macros can expand correctly. It's useful when test targets
+    in different directories refer to the same sources but test with different
+    arguments.
+
+    Examples:
+
+      # Generates the targets: foo_test_library_cpu and foo_test_library_gpu.
+      xla_test_library(
+          name = "foo_test_library",
+          srcs = ["foo_test.cc"],
+          backends = ["cpu", "gpu"],
+          deps = [...],
+      )
+      # Then use the xla_test rule to generate test targets:
+      xla_test(
+          name = "foo_test",
+          srcs = [],
+          backends = ["cpu", "gpu"],
+          deps = [...],
+          xla_test_library_deps = [":foo_test_library"],
+      )
+
+    Args:
+      name: Name of the target.
+      srcs: Sources for the target.
+      hdrs: Headers for the target.
+      deps: Dependencies of the target.
+      backends: A list of backends to generate libraries for.
+        Supported values: "cpu", "gpu". If this list is empty, the
+        library will be generated for all supported backends.
+    """
+
+    if not backends:
+        backends = all_backends
+
+    for backend in filter_backends(backends):
+        this_backend_copts = []
+        if backend in ["cpu", "gpu"]:
+            backend_deps = ["//tensorflow/compiler/xla/tests:test_macros_%s" % backend]
+        elif backend in plugins:
+            backend_deps = plugins[backend]["deps"]
+            this_backend_copts += plugins[backend]["copts"]
+        else:
+            fail("Unknown backend %s" % backend)
+
+        native.cc_library(
+            name = "%s_%s" % (name, backend),
+            srcs = srcs,
+            testonly = True,
+            hdrs = hdrs,
+            copts = ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] +
+                    this_backend_copts,
+            deps = deps + backend_deps,
+        )
+
+def generate_backend_suites(backends = []):
+    if not backends:
+        backends = all_backends
+    for backend in filter_backends(backends):
+        native.test_suite(
+            name = "%s_tests" % backend,
+            tags = ["xla_%s" % backend],
+        )
+
+def generate_backend_test_macros(backends = []):
+    if not backends:
+        backends = all_backends
+    for backend in filter_backends(backends):
+        manifest = ""
+        if backend in plugins:
+            manifest = plugins[backend]["disabled_manifest"]
+
+        native.cc_library(
+            name = "test_macros_%s" % backend,
+            testonly = True,
+            srcs = ["test_macros.cc"],
+            hdrs = ["test_macros.h"],
+            copts = [
+                "-DXLA_PLATFORM=\\\"%s\\\"" % backend.upper(),
+                "-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest,
+            ],
+            deps = [
+                "//tensorflow/compiler/xla:types",
+                "//tensorflow/core:lib",
+                "//tensorflow/core:regexp_internal",
+                "//tensorflow/core:test",
+            ],
+        )
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
index 8f86c52..8bd0a72 100644
--- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
@@ -21,64 +21,68 @@
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/test.h"
 
 namespace xla {
 
+Status VerifiedHloModule::Verify() {
+  if (computation_count() == 0) {
+    // The computation was never built. Nothing to verify.
+    return Status::OK();
+  }
+  return verifier_.Run(this).status();
+}
+
+void VerifiedHloModule::VerifyOrAddFailure(const string& message) {
+  Status status = Verify();
+  if (!status.ok()) {
+    ADD_FAILURE() << "HloVerifier failed on module " << name()
+                  << (message.empty() ? "" : absl::StrCat(" (", message, ")"))
+                  << ": " << status;
+  }
+}
+
 HloVerifiedTestBase::HloVerifiedTestBase(bool layout_sensitive,
                                          bool allow_mixed_precision)
     : HloTestBase(
           /*verifier_layout_sensitive=*/layout_sensitive,
-          /*allow_mixed_precision_in_hlo_verifier=*/allow_mixed_precision) {}
-
-HloVerifiedTestBase::~HloVerifiedTestBase() {
-  // We can't call the ASSERT or EXPECT test macros in destructors, so we
-  // perform HLO verification in TearDown, and use the CHECK here to ensure
-  // users don't accidentally override the verification.
-  CHECK(tear_down_called_)
-      << "TearDown was never called; subclasses of HloVerifiedTestBase that "
-      << "override TearDown must call the superclass TearDown.";
-}
-
-void HloVerifiedTestBase::TearDown() {
-  EXPECT_FALSE(tear_down_called_)
-      << "TearDown called more than once; it should be called exactly once.";
-  tear_down_called_ = true;
-  if (module_) {
-    VerifyModule(module_.get());
-  }
-  for (int i = 0; i < modules_.size(); ++i) {
-    VerifyModule(modules_.at(i).get());
-  }
-  HloTestBase::TearDown();
-}
-
-void HloVerifiedTestBase::VerifyModule(HloModule* module) {
-  xla::StatusOr<bool> mutated = verifier().Run(module);
-  if (!mutated.ok()) {
-    ADD_FAILURE() << "HloVerifier failed: " << mutated.status();
-  } else {
-    EXPECT_FALSE(mutated.ValueOrDie())
-        << "HloVerifier should never mutate the HloModule";
-  }
-}
+          /*allow_mixed_precision_in_hlo_verifier=*/allow_mixed_precision),
+      verifier_layout_sensitive_(layout_sensitive),
+      allow_mixed_precision_in_hlo_verifier_(allow_mixed_precision) {}
 
 HloModule& HloVerifiedTestBase::module() {
   if (!module_) {
-    module_ = HloTestBase::CreateNewModule();
+    module_ = CreateNewVerifiedModule(TestName());
   }
   return *module_;
 }
 
 HloModule* HloVerifiedTestBase::CreateNewModule(const string& name) {
-  modules_.emplace_back(HloTestBase::CreateNewModule());
+  modules_.emplace_back(CreateNewVerifiedModule(name));
   return modules_.back().get();
 }
 
 void HloVerifiedTestBase::ParseAndVerifyModule(absl::string_view hlo_text,
                                                const HloModuleConfig& config) {
   CHECK(!module_) << "Called ParseModule when test already has a module.";
-  TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text, config));
-  VerifyModule(module_.get());
+  module_ = CreateNewVerifiedModule(TestName());
+  TF_CHECK_OK(ParseHloString(hlo_text, module_.get()));
+  module_->VerifyOrAddFailure("after parsing");
 }
+
+StatusOr<std::unique_ptr<VerifiedHloModule>>
+HloVerifiedTestBase::ParseAndReturnVerifiedModule(
+    absl::string_view hlo_text, const HloModuleConfig& config) {
+  auto module = CreateNewVerifiedModule(TestName());
+  TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get()));
+  TF_RETURN_IF_ERROR(module->Verify());
+  return std::move(module);
+}
+
+std::unique_ptr<VerifiedHloModule> HloVerifiedTestBase::CreateNewVerifiedModule(
+    const string& name) {
+  return absl::make_unique<VerifiedHloModule>(
+      name, GetModuleConfigForTest(), verifier_layout_sensitive_,
+      allow_mixed_precision_in_hlo_verifier_);
+}
+
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
index 8fbc4fa..388a99b 100644
--- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
@@ -20,53 +20,84 @@
 #include <memory>
 #include <utility>
 
+#include "absl/base/macros.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
 
 namespace xla {
 
-// A base class for HLO tests that stores a default HloModule, and automatically
-// performs verification on that module on tear-down.
+// An HLO module derived class which verifies itself on destruction. This class
+// is intended to be used in unit tests. Any verification errors are raised via
+// ADD_FAILURE.
+class VerifiedHloModule : public HloModule {
+ public:
+  VerifiedHloModule(const string& name, const HloModuleConfig& config,
+                    bool verifier_layout_sensitive,
+                    bool allow_mixed_precision_in_hlo_verifier)
+      : HloModule(name, config),
+        verifier_(verifier_layout_sensitive,
+                  allow_mixed_precision_in_hlo_verifier) {}
+
+  ~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); }
+
+  // Verifies the module using HloVerifier and returns the status.
+  Status Verify();
+
+  // Verifies the module and flags any error with ADD_FAILURE. 'message' is
+  // included in the failure message.
+  void VerifyOrAddFailure(const string& message);
+
+ private:
+  HloVerifier verifier_;
+};
+
+// A base class for HLO tests that stores a default VerifiedHloModule.
 class HloVerifiedTestBase : public HloTestBase {
  protected:
-  explicit HloVerifiedTestBase(bool layout_sensitive = false,
-                               bool allow_mixed_precision = false);
-  ~HloVerifiedTestBase() override;
+  HloVerifiedTestBase(bool layout_sensitive = false,
+                      bool allow_mixed_precision = false);
 
   // Constructs a default shape verifier.
   std::unique_ptr<ShapeVerifier> MakeShapeVerifier();
 
-  // Performs verification on the default HloModule returned by module().
-  // Automatically called by the testing framework for each test.
-  //
-  // REQUIRED: subclasses that override TearDown() must call this explicitly.
-  void TearDown() override;
-
   // Returns the default HloModule, lazily creating it if necessary via
   // CreateNewVerifiedModule().
+  ABSL_DEPRECATED("Use CreateNewVerifiedModule() instead.")
   HloModule& module();
+
+  ABSL_DEPRECATED("Use ParseAndReturnVerifiedModule() instead.")
   void ParseAndVerifyModule(absl::string_view hlo_text,
                             const HloModuleConfig& config = HloModuleConfig());
 
+  // Parses the given string and returns the module as a VerifiedHloModule.
+  StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
+      absl::string_view hlo_text,
+      const HloModuleConfig& config = HloModuleConfig());
+
   // Creates a new module for a test, and stores it in modules_ so it can be
   // verified. Intentionally hides HloTestBase::CreateNewModule, to prevent
   // creation of unverified modules.
+  ABSL_DEPRECATED("Use CreateNewVerifiedModule() instead.")
   HloModule* CreateNewModule(const string& name = TestName());
 
- private:
-  void VerifyModule(HloModule* module);
+  // Creates and returns a verified HLO module with the given name.
+  std::unique_ptr<VerifiedHloModule> CreateNewVerifiedModule(
+      const string& name = TestName());
 
+ private:
   // It is confusing to store modules created by module() and CreateNewModule()
   // in different fields, but it allows us to migrate tests to
   // HloVerifiedTestBase more easily, so it's a win because we can verify more
   // modules. See b/80488902.
   //
   // Lazily populated. Access via module().
-  std::unique_ptr<HloModule> module_;
-  // Populated by calls to CreateNewModule.
-  std::vector<std::unique_ptr<HloModule>> modules_;
+  std::unique_ptr<VerifiedHloModule> module_;
 
-  bool tear_down_called_ = false;
+  // Populated by calls to CreateNewModule.
+  std::vector<std::unique_ptr<VerifiedHloModule>> modules_;
+
+  bool verifier_layout_sensitive_;
+  bool allow_mixed_precision_in_hlo_verifier_;
 };
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc
new file mode 100644
index 0000000..5c0263e
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc
@@ -0,0 +1,158 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
+
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_verifier.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+// This class includes unit tests which are expected to fail because invalid HLO
+// modules are intentionally built. Unfortunately, TensorFlow doesn't appear to
+// include the necessary gunit parts to test this test machinery (it needs the
+// EXPECT_NONFATAL_FAILURE macro). The failing tests are therefore disabled; they
+// can be run manually with disabled tests enabled and the failures compared
+// against expectations by hand.
+class HloVerifiedTestBaseTest : public HloVerifiedTestBase {};
+
+XLA_TEST_F(HloVerifiedTestBaseTest, NoModule) {
+  // Test shouldn't fail if no module is created at all.
+}
+
+XLA_TEST_F(HloVerifiedTestBaseTest, GoodLazilyCreatedModule) {
+  // Use module() to lazily create an empty module, build it up, and verify no
+  // failures.
+  HloModule& hlo_module = module();
+  auto builder = HloComputation::Builder(TestName());
+  auto input = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+  builder.AddInstruction(
+      HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input));
+  hlo_module.AddEntryComputation(builder.Build());
+}
+
+// This test is expected to fail. See test class comment.
+XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_BadLazilyCreatedModule) {
+  // Use module() to lazily create an empty module and build up an invalid
+  // module.
+  HloModule& hlo_module = module();
+  auto builder = HloComputation::Builder(TestName());
+  auto input = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+  builder.AddInstruction(
+      HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input));
+  hlo_module.AddEntryComputation(builder.Build());
+
+  *hlo_module.entry_computation()->root_instruction()->mutable_shape() =
+      ShapeUtil::MakeShape(PRED, {1, 2, 3});
+}
+
+XLA_TEST_F(HloVerifiedTestBaseTest, GoodCreateNewModule) {
+  // Call CreateNewModule and build up a valid module.
+  HloModule* module = CreateNewModule();
+  auto builder = HloComputation::Builder(TestName());
+  auto input = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+  builder.AddInstruction(
+      HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input));
+  module->AddEntryComputation(builder.Build());
+}
+
+// This test is expected to fail. See test class comment.
+XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_BadCreateNewModule) {
+  // Call CreateNewModule and build up an invalid module.
+  HloModule* module = CreateNewModule();
+  auto builder = HloComputation::Builder(TestName());
+  auto input = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+  builder.AddInstruction(
+      HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input));
+  module->AddEntryComputation(builder.Build());
+
+  *module->entry_computation()->root_instruction()->mutable_shape() =
+      ShapeUtil::MakeShape(PRED, {1, 2, 3});
+}
+
+XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndVerifyModuleGood) {
+  const char* const hlo_string = R"(
+HloModule ParseAndVerifyModuleGood
+
+ENTRY entry {
+  x = f32[] parameter(0)
+  y = f32[] parameter(1)
+  ROOT add = f32[] add(x,y)
+}
+)";
+
+  ParseAndVerifyModule(hlo_string);
+  EXPECT_EQ(module().entry_computation()->instruction_count(), 3);
+}
+
+XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndReturnVerifiedModuleGood) {
+  const char* const hlo_string = R"(
+HloModule ParseAndReturnVerifiedModuleGood
+
+ENTRY entry {
+  x = f32[] parameter(0)
+  y = f32[] parameter(1)
+  ROOT add = f32[] add(x,y)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  EXPECT_EQ(module->entry_computation()->instruction_count(), 3);
+}
+
+XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndReturnVerifiedModuleInvalidText) {
+  const char* const hlo_string = R"(
+HloModule ParseAndReturnVerifiedModuleInvalidText
+
+ENTRY entry {
+  x = f32[] parameter(0)
+  y = f32[] parameter(1)
+  ROOT add = f32[] add(x,y)
+}
+
+RANDOM GARBAGE
+)";
+
+  ASSERT_IS_NOT_OK(ParseAndReturnVerifiedModule(hlo_string).status());
+}
+
+// This test is expected to fail. See test class comment.
+XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_ParseAndReturnVerifiedModuleBad) {
+  const char* const hlo_string = R"(
+HloModule ParseAndReturnVerifiedModuleBad
+
+ENTRY entry {
+  x = f32[] parameter(0)
+  y = f32[] parameter(1)
+  ROOT add = f32[1234] add(x,y)
+}
+)";
+
+  ASSERT_IS_NOT_OK(ParseAndReturnVerifiedModule(hlo_string).status());
+}
+
+}  // namespace
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index 63491a9..c25ccaf 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -1303,11 +1303,19 @@
      /*pad_high=*/{0},
      /*reducer=*/Reducer::kAdd},
 
+    // The pattern generated by inclusive scan (cumsum/cumprod).
     {/*base_bounds=*/{4096}, /*window_bounds=*/{4096},
      /*strides=*/{1},
      /*pad_low=*/{4095},
      /*pad_high=*/{0},
      /*reducer=*/Reducer::kMax},
+
+    // The pattern generated by exclusive scan (cumsum/cumprod).
+    {/*base_bounds=*/{4096}, /*window_bounds=*/{4096},
+     /*strides=*/{1},
+     /*pad_low=*/{4096},
+     /*pad_high=*/{0},
+     /*reducer=*/Reducer::kMax},
 };
 
 string R1ReduceWindowTestDataToString(
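
The two new R1 entries above encode the reduce-window shapes produced by cumulative scans. Below is a minimal NumPy sketch (not part of the diff; the helper name and the use of a sum reducer instead of the test's kMax reducer are illustrative only) of why the inclusive pattern pads the low side by n - 1 while the exclusive pattern pads by n:

```python
import numpy as np

def reduce_window_1d(x, window, pad_low, init=0.0):
    """Toy 1-D reduce-window with a sum reducer and stride 1 (illustrative only)."""
    padded = np.concatenate([np.full(pad_low, init), x])
    # One output per valid window position, like the XLA ReduceWindow op.
    return np.array([padded[i:i + window].sum()
                     for i in range(len(padded) - window + 1)])

x = np.arange(1.0, 9.0)  # n = 8
n = len(x)
# Inclusive scan: window n, pad_low n - 1  ->  output i sees inputs [0, i].
assert np.allclose(reduce_window_1d(x, n, n - 1), np.cumsum(x))
# Exclusive scan: window n, pad_low n  ->  output i sees inputs [0, i) only
# (the XLA op produces n + 1 outputs here; the first n form the shifted cumsum).
exclusive = reduce_window_1d(x, n, n)[:n]
assert np.allclose(exclusive, np.cumsum(x) - x)
```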
diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc
index a40c2d7..2cc33ab 100644
--- a/tensorflow/compiler/xla/tests/slice_test.cc
+++ b/tensorflow/compiler/xla/tests/slice_test.cc
@@ -412,6 +412,7 @@
         R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{7, 11}}, {{0, 1}}},  //
         R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{11, 7}}, {{1, 0}}},  //
         R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{11, 7}}, {{0, 1}}},  //
+        R2Spec{8672, 512, {{8, 0}}, {{8672, 512}}, {{542, 1}}, {{1, 0}}},   //
         R2Spec{
             511, 513, {{129, 300}}, {{400, 500}}, {{101, 129}}, {{1, 0}}},  //
         R2Spec{
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc
index 7abd865..8b1b9e1 100644
--- a/tensorflow/compiler/xla/tests/while_test.cc
+++ b/tensorflow/compiler/xla/tests/while_test.cc
@@ -763,9 +763,7 @@
   ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
 }
 
-// Test while nodes that share the while body computation.
-// TODO(b/37245345): Fails on GPU backend.
-TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) {
+TEST_F(WhileTest, WhileLoopsWithSharedBodyAndInit) {
   std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
                                        ShapeUtil::MakeShape(F32, {10})};
   Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
diff --git a/tensorflow/compiler/xrt/tests/BUILD b/tensorflow/compiler/xrt/tests/BUILD
index 09ab4ed..b6dcfc4 100644
--- a/tensorflow/compiler/xrt/tests/BUILD
+++ b/tensorflow/compiler/xrt/tests/BUILD
@@ -8,6 +8,10 @@
 )
 
 load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test", "tf_cc_test")
+load(
+    "//tensorflow/core:platform/default/build_config_root.bzl",
+    "tf_cuda_tests_tags",
+)
 
 cc_library(
     name = "raw_api_test_lib",
@@ -57,7 +61,7 @@
     size = "medium",
     srcs = [],
     args = ["--xla_test_device=XLA_GPU"],
-    tags = ["requires-gpu-sm35"],
+    tags = tf_cuda_tests_tags(),
     deps = [
         ":raw_api_test_lib",
         "//tensorflow/compiler/jit:xla_gpu_device",
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index d98a249..e1af52c 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -60,7 +60,6 @@
         "//tensorflow/contrib/learn",
         "//tensorflow/contrib/legacy_seq2seq:seq2seq_py",
         "//tensorflow/contrib/libsvm",
-        "//tensorflow/contrib/linalg:linalg_py",
         "//tensorflow/contrib/linear_optimizer:sdca_estimator_py",
         "//tensorflow/contrib/linear_optimizer:sdca_ops_py",
         "//tensorflow/contrib/lite/python:lite",
diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py
index 9478e42..e71b0e0 100644
--- a/tensorflow/contrib/__init__.py
+++ b/tensorflow/contrib/__init__.py
@@ -63,7 +63,6 @@
 from tensorflow.contrib import layers
 from tensorflow.contrib import learn
 from tensorflow.contrib import legacy_seq2seq
-from tensorflow.contrib import linalg
 from tensorflow.contrib import linear_optimizer
 from tensorflow.contrib import lookup
 from tensorflow.contrib import losses
diff --git a/tensorflow/contrib/all_reduce/python/all_reduce_test.py b/tensorflow/contrib/all_reduce/python/all_reduce_test.py
index b3f5d92..9a8f62b 100644
--- a/tensorflow/contrib/all_reduce/python/all_reduce_test.py
+++ b/tensorflow/contrib/all_reduce/python/all_reduce_test.py
@@ -149,7 +149,7 @@
     num_devices = num_workers * num_gpus
     dev_list = ["/replica:0/task:0/device:CPU:0"
                 for _ in range(num_devices)]
-    with self.test_session():
+    with self.cached_session():
       input_tensors = self._buildInitialVars(shape, dev_list)
       un_op = lambda x: math_ops.div(
           x, constant_op.constant(num_devices, dtype=types_pb2.DT_FLOAT))
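
The test_session() -> cached_session() renames that recur throughout the rest of this change all follow the same pattern: cached_session() returns a session that is created once and reused for the remainder of the test method, which is what test_session() already did despite its name. A minimal sketch of the target pattern, assuming TF 1.x graph mode (the SquareTest class is a hypothetical example, not part of the diff):

```python
import tensorflow as tf

class SquareTest(tf.test.TestCase):  # hypothetical example class

  def testSquare(self):
    # cached_session() creates the session on first use and reuses it for
    # later calls within the same test method.
    with self.cached_session() as sess:
      x = tf.constant([2.0, 3.0])
      y = tf.square(x)
      self.assertAllClose(sess.run(y), [4.0, 9.0])

if __name__ == "__main__":
  tf.test.main()
```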
diff --git a/tensorflow/contrib/batching/python/ops/batch_ops_test.py b/tensorflow/contrib/batching/python/ops/batch_ops_test.py
index 7846814..01ee870 100644
--- a/tensorflow/contrib/batching/python/ops/batch_ops_test.py
+++ b/tensorflow/contrib/batching/python/ops/batch_ops_test.py
@@ -43,7 +43,7 @@
 
   def testBasicBatch(self):
     """Tests that a single batched tensor executes together and only once."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
       batched, index, _ = batch_ops.batch(
           [inp], num_batch_threads=1, max_batch_size=2,
@@ -83,7 +83,7 @@
 
   def testBatchWithPadding(self):
     """Test that batching with padding up to an allowed batch size works."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       inp = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
       batched, index, _ = batch_ops.batch(
           [inp], num_batch_threads=1, max_batch_size=10,
@@ -113,7 +113,7 @@
 
   def testMultipleBatch(self):
     """Tests that multiple batched tensors execute together."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
       inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
       batched, _, _ = batch_ops.batch(
@@ -152,7 +152,7 @@
 
   def testIllegalBatchDifferentDim0Sizes(self):
     """Tests illegally feeding tensors with different dim0 sizes."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
       inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
       batched, index, _ = batch_ops.batch(
@@ -166,7 +166,7 @@
 
   def testBasicUnbatch(self):
     """Tests that batch and unbatch work together."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
       batched, index, id_t = batch_ops.batch(
           [inp], num_batch_threads=1, max_batch_size=10,
@@ -190,7 +190,8 @@
 
   def testBasicUnbatchV1Decorated(self):
     """Tests that the batch_function_v1 decorator works."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
+
       @batch_ops.batch_function_v1(1, 10, 100000)
       def computation(in_t):
         return in_t + 1
@@ -211,7 +212,7 @@
 
   def testBasicUnbatchDecorated(self):
     """Tests that the batch_function decorator works."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # TODO(apassos): Removing this line causes test flakiness! Ideally should
       # be investigated.
       default_inp = array_ops.placeholder_with_default(2, shape=[])  # pylint: disable=unused-variable
@@ -236,7 +237,7 @@
 
   def testBatchDecoratedWithCapturedInput(self):
     """Tests that the batch_function decorator works."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
       captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
 
@@ -260,7 +261,7 @@
 
   def testBatchFunctionOp(self):
     """Tests that the batch_function op works."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
 
       @function.Defun(dtypes.int32)
       def computation(in_t):
@@ -289,7 +290,7 @@
 
   def testBatchFunctionOpWithCapturedInput(self):
     """Tests that batch_function op works with captured input."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
       captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
       inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
@@ -323,7 +324,7 @@
 
   def testBatchFunctionOpWithInputError(self):
     """Tests that batch_function op works with error in the inputs."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
 
       @function.Defun(dtypes.int32, dtypes.int32)
@@ -346,7 +347,7 @@
 
   def testBasicUnbatchDecoratedWithReshape(self):
     """Tests that the batch_function decorator works."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
 
       @batch_ops.batch_function(1, 10, 100000)
       def computation(in_t):
@@ -368,7 +369,7 @@
 
   def testUnbatchTimeout(self):
     """Tests that the unbatch timeout works."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
       batched, index, id_t = batch_ops.batch(
           [inp], num_batch_threads=1, max_batch_size=2,
@@ -410,7 +411,7 @@
 
   def testUnbatchGrad(self):
     """Tests that batch and unbatch are differentiable."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       inp = array_ops.placeholder(dtype=dtypes.float32, shape=[1])
       batched, index, id_t = batch_ops.batch(
           [inp], num_batch_threads=1, max_batch_size=2,
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py
index 9e6a146..13215ff 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py
@@ -42,7 +42,7 @@
 
   def test_normal_integral_mean_and_var_correctly_estimated(self):
     n = int(1e6)
-    with self.test_session():
+    with self.cached_session():
       mu_p = constant_op.constant([-1.0, 1.0], dtype=dtypes.float64)
       mu_q = constant_op.constant([0.0, 0.0], dtype=dtypes.float64)
       sigma_p = constant_op.constant([0.5, 0.5], dtype=dtypes.float64)
@@ -72,7 +72,7 @@
     # Test that importance sampling can correctly estimate the probability that
     # the product of components in a MultivariateNormal are > 0.
     n = 1000
-    with self.test_session():
+    with self.cached_session():
       p = mvn_diag_lib.MultivariateNormalDiag(
           loc=[0.], scale_diag=[1.0, 1.0])
       q = mvn_diag_lib.MultivariateNormalDiag(
@@ -99,7 +99,7 @@
   def test_normal_distribution_second_moment_estimated_correctly(self):
     # Test the importance sampled estimate against an analytical result.
     n = int(1e6)
-    with self.test_session():
+    with self.cached_session():
       mu_p = constant_op.constant([0.0, 0.0], dtype=dtypes.float64)
       mu_q = constant_op.constant([-1.0, 1.0], dtype=dtypes.float64)
       sigma_p = constant_op.constant([1.0, 2 / 3.], dtype=dtypes.float64)
@@ -127,7 +127,7 @@
   """Test the private method 'get_samples'."""
 
   def test_raises_if_both_z_and_n_are_none(self):
-    with self.test_session():
+    with self.cached_session():
       dist = normal_lib.Normal(loc=0., scale=1.)
       z = None
       n = None
@@ -136,7 +136,7 @@
         _get_samples(dist, z, n, seed)
 
   def test_raises_if_both_z_and_n_are_not_none(self):
-    with self.test_session():
+    with self.cached_session():
       dist = normal_lib.Normal(loc=0., scale=1.)
       z = dist.sample(seed=42)
       n = 1
@@ -145,7 +145,7 @@
         _get_samples(dist, z, n, seed)
 
   def test_returns_n_samples_if_n_provided(self):
-    with self.test_session():
+    with self.cached_session():
       dist = normal_lib.Normal(loc=0., scale=1.)
       z = None
       n = 10
@@ -154,7 +154,7 @@
       self.assertEqual((10,), z.get_shape())
 
   def test_returns_z_if_z_provided(self):
-    with self.test_session():
+    with self.cached_session():
       dist = normal_lib.Normal(loc=0., scale=1.)
       z = dist.sample(10, seed=42)
       n = None
@@ -166,7 +166,7 @@
 class ExpectationTest(test.TestCase):
 
   def test_works_correctly(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = constant_op.constant([-1e6, -100, -10, -1, 1, 10, 100, 1e6])
       p = normal_lib.Normal(loc=x, scale=1.)
 
@@ -213,7 +213,7 @@
                           rtol=0.05, atol=0.)
 
   def test_docstring_example_normal(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       num_draws = int(1e5)
       mu_p = constant_op.constant(0.)
       mu_q = constant_op.constant(1.)
diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
index 9afe3df..18d40fc 100644
--- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
+++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
@@ -27,6 +27,7 @@
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
+from tensorflow.python.util import deprecation
 
 __all__ = [
     'expectation',
@@ -66,7 +67,7 @@
       shape broadcastable to `q.batch_shape`.
       For example, `log_p` works "just like" `sampling_dist_q.log_prob`.
     sampling_dist_q:  The sampling distribution.
-      `tf.contrib.distributions.Distribution`.
+      `tfp.distributions.Distribution`.
       `float64` `dtype` recommended.
       `log_p` and `q` should be supported on the same set.
     z:  `Tensor` of samples from `q`, produced by `q.sample` for some `n`.
@@ -141,7 +142,7 @@
       shape broadcastable to `q.batch_shape`.
       For example, `log_p` works "just like" `q.log_prob`.
     sampling_dist_q:  The sampling distribution.
-      `tf.contrib.distributions.Distribution`.
+      `tfp.distributions.Distribution`.
       `float64` `dtype` recommended.
       `log_p` and `q` should be supported on the same set.
     z:  `Tensor` of samples from `q`, produced by `q.sample` for some `n`.
@@ -188,6 +189,12 @@
   return log_mean_of_values
 
 
+@deprecation.deprecated(
+    '2018-10-01',
+    'The tf.contrib.bayesflow library has moved to '
+    'TensorFlow Probability (https://github.com/tensorflow/probability). '
+    'Use `tfp.monte_carlo.expectation` instead.',
+    warn_once=True)
 def expectation(f, samples, log_prob=None, use_reparametrization=True,
                 axis=0, keep_dims=False, name=None):
   r"""Computes the Monte-Carlo approximation of \\(E_p[f(X)]\\).
@@ -236,17 +243,17 @@
   Example Use:
 
   ```python
-  bf = tf.contrib.bayesflow
-  ds = tf.contrib.distributions
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
 
   # Monte-Carlo approximation of a reparameterized distribution, e.g., Normal.
 
   num_draws = int(1e5)
-  p = ds.Normal(loc=0., scale=1.)
-  q = ds.Normal(loc=1., scale=2.)
-  exact_kl_normal_normal = ds.kl_divergence(p, q)
+  p = tfd.Normal(loc=0., scale=1.)
+  q = tfd.Normal(loc=1., scale=2.)
+  exact_kl_normal_normal = tfd.kl_divergence(p, q)
   # ==> 0.44314718
-  approx_kl_normal_normal = bf.expectation(
+  approx_kl_normal_normal = tfp.monte_carlo.expectation(
       f=lambda x: p.log_prob(x) - q.log_prob(x),
       samples=p.sample(num_draws, seed=42),
       log_prob=p.log_prob,
@@ -260,9 +267,9 @@
   num_draws = int(1e5)
   p = tfd.Gamma(concentration=1., rate=1.)
   q = tfd.Gamma(concentration=2., rate=3.)
-  exact_kl_gamma_gamma = ds.kl_divergence(p, q)
+  exact_kl_gamma_gamma = tfd.kl_divergence(p, q)
   # ==> 0.37999129
-  approx_kl_gamma_gamma = bf.expectation(
+  approx_kl_gamma_gamma = tfp.monte_carlo.expectation(
       f=lambda x: p.log_prob(x) - q.log_prob(x),
       samples=p.sample(num_draws, seed=42),
       log_prob=p.log_prob,
@@ -278,7 +285,7 @@
   KL-divergence, the following is preferred:
 
   ```python
-  approx_kl_p_q = bf.monte_carlo_csiszar_f_divergence(
+  approx_kl_p_q = tfp.vi.monte_carlo_csiszar_f_divergence(
       f=bf.kl_reverse,
       p_log_prob=q.log_prob,
       q=p,
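
The @deprecation.deprecated decorator applied to expectation() above logs a deprecation warning on first use (warn_once=True) that names the removal date and the suggested replacement, while leaving the function fully functional. A minimal sketch of the decorator's behavior on a hypothetical stand-in function (not part of the diff):

```python
from tensorflow.python.util import deprecation

@deprecation.deprecated(
    '2018-10-01',                                  # date shown in the warning
    'Use `tfp.monte_carlo.expectation` instead.',  # instructions shown in the warning
    warn_once=True)                                # warn only on the first call
def my_expectation(f, samples):  # hypothetical stand-in for the real function
  return sum(f(s) for s in samples) / len(samples)

# The first call logs the deprecation warning; subsequent calls are silent.
print(my_expectation(lambda x: x * x, [1.0, 2.0, 3.0]))
```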
diff --git a/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py
index e36f7f3..316da9e 100644
--- a/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py
+++ b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py
@@ -61,7 +61,7 @@
     n = itr.get_next()
     expected = list(self.COMMON_ROW_KEYS)
     expected.reverse()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._writeCommonValues(sess)
       sess.run(itr.initializer)
       for i in range(3):
@@ -84,7 +84,7 @@
     expected_keys.reverse()
     expected_values = list(self.COMMON_VALUES)
     expected_values.reverse()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._writeCommonValues(sess)
       sess.run(itr.initializer)
       for i in range(3):
@@ -125,7 +125,7 @@
     expected_keys = list(self.COMMON_ROW_KEYS)
     expected_values = list(self.COMMON_VALUES)
     expected_tuples = zip(expected_keys, expected_values)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._writeCommonValues(sess)
       sess.run(itr.initializer)
       for i, elem in enumerate(expected_tuples):
@@ -144,7 +144,7 @@
     itr = ds.make_initializable_iterator()
     n = itr.get_next()
     expected_key = self.COMMON_ROW_KEYS[0]
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._writeCommonValues(sess)
       sess.run(itr.initializer)
       output = sess.run(n)
@@ -163,7 +163,7 @@
   def runSampleKeyPairsTest(self, ds, expected_key_pairs):
     itr = ds.make_initializable_iterator()
     n = itr.get_next()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._writeCommonValues(sess)
       sess.run(itr.initializer)
       for i, elems in enumerate(expected_key_pairs):
@@ -219,7 +219,7 @@
     ds = bigtable_api._BigtableSampleKeyPairsDataset(
         self._table, prefix="r", start="r1", end="")
     itr = ds.make_initializable_iterator()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with self.assertRaises(errors.InvalidArgumentError):
         sess.run(itr.initializer)
 
@@ -227,7 +227,7 @@
     ds = bigtable_api._BigtableSampleKeyPairsDataset(
         self._table, prefix="r", start="", end="r3")
     itr = ds.make_initializable_iterator()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with self.assertRaises(errors.InvalidArgumentError):
         sess.run(itr.initializer)
 
@@ -235,7 +235,7 @@
     ds = self._table.parallel_scan_prefix(prefix="r", cf1="c1")
     itr = ds.make_initializable_iterator()
     n = itr.get_next()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._writeCommonValues(sess)
       sess.run(itr.initializer)
       expected_values = list(zip(self.COMMON_ROW_KEYS, self.COMMON_VALUES))
@@ -253,7 +253,7 @@
     ds = self._table.parallel_scan_range(start="r1", end="r4", cf1="c1")
     itr = ds.make_initializable_iterator()
     n = itr.get_next()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._writeCommonValues(sess)
       sess.run(itr.initializer)
       expected_values = list(zip(self.COMMON_ROW_KEYS, self.COMMON_VALUES))
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD
index 5fcb19a..14b6fc4 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD
@@ -173,6 +173,7 @@
 py_test(
     name = "dnn_tree_combined_estimator_test",
     size = "medium",
+    timeout = "long",
     srcs = ["dnn_tree_combined_estimator_test.py"],
     srcs_version = "PY2AND3",
     tags = [
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py
index 78232fa..48f12a6 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py
@@ -51,6 +51,7 @@
     feature_columns: A list of feature columns.
     export_input_fn: A function that takes no arguments and returns an
       `InputFnOps`.
+    use_core_columns: A boolean, whether core feature columns were used.
 
   Returns:
     An `ExportStrategy`.
@@ -196,7 +197,7 @@
           matching_id.int64_value = split.feature_id
           node.custom_left_child_test.Pack(categorical_test)
         else:
-          raise ValueError("Unexpected node type %s", node_type)
+          raise ValueError("Unexpected node type %s" % node_type)
         node.left_child_id.value = split.left_id
         node.right_child_id.value = split.right_id
   return model_and_features
@@ -236,7 +237,7 @@
         assert tree_node.node_metadata.gain == 0
         continue
       else:
-        raise ValueError("Unexpected split type %s", node_type)
+        raise ValueError("Unexpected split type %s" % node_type)
       # Apply shrinkage factor. It is important since it is not always uniform
       # across different trees.
       sums[split_column] += (
diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
index 51e0c2e..af7006b 100644
--- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
@@ -579,13 +579,6 @@
         const int end_index =
             partition_boundaries[non_empty_partitions[root_idx]][j + 1]
                 .start_index;
-        CHECK(bucket_ids_and_dimensions(start_index, 1) ==
-              bucket_ids_and_dimensions(end_index - 1, 1))
-            << "For bucket " << bucket_ids_and_dimensions(start_index, 0)
-            << " the dimension was "
-            << bucket_ids_and_dimensions(start_index, 1) << " and for "
-            << bucket_ids_and_dimensions(end_index - 1, 0) << " "
-            << bucket_ids_and_dimensions(end_index - 1, 1);
         if (bucket_ids_and_dimensions(start_index, 0) == bias_feature_id) {
           // 0-dimension case which has a first bucket for catch all feature.
           CHECK(bucket_ids_and_dimensions(start_index, 1) == 0)
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
index 73e41bc4..9d9941f 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
@@ -86,7 +86,7 @@
 
   def testExtractFeatures(self):
     """Tests feature extraction."""
-    with self.test_session():
+    with self.cached_session():
       features = {}
       features["dense_float"] = array_ops.zeros([2, 1], dtypes.float32)
       features["sparse_float"] = sparse_tensor.SparseTensor(
@@ -128,7 +128,7 @@
 
   def testExtractFeaturesWithTransformation(self):
     """Tests feature extraction."""
-    with self.test_session():
+    with self.cached_session():
       features = {}
       features["dense_float"] = array_ops.zeros([2, 1], dtypes.float32)
       features["sparse_float"] = sparse_tensor.SparseTensor(
@@ -178,7 +178,7 @@
 
   def testExtractFeaturesFromCoreFeatureColumns(self):
     """Tests feature extraction when using core columns."""
-    with self.test_session():
+    with self.cached_session():
       features = {}
       # Sparse float column does not exist in core, so only dense numeric and
       # categorical.
@@ -213,7 +213,7 @@
 
   def testTrainFnChiefNoBiasCentering(self):
     """Tests the train function running on chief without bias centering."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       ensemble_handle = model_ops.tree_ensemble_variable(
           stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
       learner_config = learner_pb2.LearnerConfig()
@@ -316,7 +316,7 @@
       self.assertProtoEquals(expected_tree, output.trees[0])
 
   def testObliviousDecisionTreeAsWeakLearner(self):
-    with self.test_session():
+    with self.cached_session():
       ensemble_handle = model_ops.tree_ensemble_variable(
           stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
       learner_config = learner_pb2.LearnerConfig()
@@ -473,7 +473,7 @@
 
   def testTrainFnChiefSparseAndDense(self):
     """Tests the train function with sparse and dense features."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       ensemble_handle = model_ops.tree_ensemble_variable(
           stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
       learner_config = learner_pb2.LearnerConfig()
@@ -580,7 +580,7 @@
 
   def testTrainFnChiefScalingNumberOfExamples(self):
     """Tests the train function running on chief without bias centering."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       ensemble_handle = model_ops.tree_ensemble_variable(
           stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
       learner_config = learner_pb2.LearnerConfig()
@@ -685,7 +685,7 @@
 
   def testTrainFnChiefWithBiasCentering(self):
     """Tests the train function running on chief with bias centering."""
-    with self.test_session():
+    with self.cached_session():
       ensemble_handle = model_ops.tree_ensemble_variable(
           stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
       learner_config = learner_pb2.LearnerConfig()
@@ -757,7 +757,7 @@
 
   def testTrainFnNonChiefNoBiasCentering(self):
     """Tests the train function running on worker without bias centering."""
-    with self.test_session():
+    with self.cached_session():
       ensemble_handle = model_ops.tree_ensemble_variable(
           stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
       learner_config = learner_pb2.LearnerConfig()
@@ -821,7 +821,7 @@
 
   def testTrainFnNonChiefWithCentering(self):
     """Tests the train function running on worker with bias centering."""
-    with self.test_session():
+    with self.cached_session():
       ensemble_handle = model_ops.tree_ensemble_variable(
           stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
       learner_config = learner_pb2.LearnerConfig()
@@ -885,7 +885,7 @@
 
   def testPredictFn(self):
     """Tests the predict function."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create ensemble with one bias node.
       ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
       text_format.Merge(
@@ -939,7 +939,7 @@
 
   def testPredictFnWithLeafIndexAdvancedLeft(self):
     """Tests the predict function with output leaf ids."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create ensemble with one bias node.
       ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
       text_format.Merge(
@@ -1051,7 +1051,7 @@
 
   def testTrainFnMulticlassFullHessian(self):
     """Tests the GBDT train for multiclass full hessian."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       ensemble_handle = model_ops.tree_ensemble_variable(
           stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
 
@@ -1155,7 +1155,7 @@
 
   def testTrainFnMulticlassDiagonalHessian(self):
     """Tests the GBDT train for multiclass diagonal hessian."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       ensemble_handle = model_ops.tree_ensemble_variable(
           stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
 
@@ -1259,7 +1259,7 @@
 
   def testTrainFnMulticlassTreePerClass(self):
     """Tests the GBDT train for multiclass tree per class strategy."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       ensemble_handle = model_ops.tree_ensemble_variable(
           stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
 
@@ -1374,7 +1374,7 @@
 
   def testTrainFnChiefFeatureSelectionReachedLimitNoGoodSplit(self):
     """Tests the train function running on chief with feature selection."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       ensemble_handle = model_ops.tree_ensemble_variable(
           stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
       learner_config = learner_pb2.LearnerConfig()
@@ -1493,7 +1493,7 @@
 
   def testTrainFnChiefFeatureSelectionWithGoodSplits(self):
     """Tests the train function running on chief with feature selection."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       ensemble_handle = model_ops.tree_ensemble_variable(
           stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
       learner_config = learner_pb2.LearnerConfig()
@@ -1610,7 +1610,7 @@
 
   def testTrainFnChiefFeatureSelectionReachedLimitIncrementAttemptedLayer(self):
     """Tests the train function running on chief with feature selection."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
       tree = tree_ensemble_config.trees.add()
 
@@ -1720,7 +1720,7 @@
 
   def testResetModelBeforeAndAfterSplit(self):
     """Tests whether resetting works."""
-    with self.test_session():
+    with self.cached_session():
       # First build a small tree and train it to verify training works.
       ensemble_handle = model_ops.tree_ensemble_variable(
           stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
@@ -1854,7 +1854,7 @@
 
   def testResetModelNonChief(self):
     """Tests the reset function on a non-chief worker."""
-    with self.test_session():
+    with self.cached_session():
       # Create ensemble with one bias node.
       ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
       text_format.Merge(
@@ -1930,7 +1930,7 @@
 
   def testResetModelWithCenterBias(self):
     """Tests the reset function running on chief with bias centering."""
-    with self.test_session():
+    with self.cached_session():
       ensemble_handle = model_ops.tree_ensemble_variable(
           stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
       learner_config = learner_pb2.LearnerConfig()
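
The `test_session` to `cached_session` migration throughout this file follows the TF test-base convention in which `cached_session` returns a session that is cached and reused for the duration of the test method, avoiding repeated session construction. A minimal sketch of the pattern, assuming a TF 1.x-era `tf.test.TestCase` (the test class and values are made up):

```python
# Minimal sketch of the cached_session pattern used throughout this diff.
import tensorflow as tf


class MySmallTest(tf.test.TestCase):

  def testAdd(self):
    c = tf.constant(2) + tf.constant(3)
    # cached_session reuses one underlying session for this test method, so
    # repeated `with` blocks do not pay session startup cost again.
    with self.cached_session() as sess:
      self.assertEqual(5, sess.run(c))


if __name__ == "__main__":
  tf.test.main()
```
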
diff --git a/tensorflow/contrib/boosted_trees/python/utils/losses_test.py b/tensorflow/contrib/boosted_trees/python/utils/losses_test.py
index ccb8509..cc22504 100644
--- a/tensorflow/contrib/boosted_trees/python/utils/losses_test.py
+++ b/tensorflow/contrib/boosted_trees/python/utils/losses_test.py
@@ -45,7 +45,7 @@
 
     eps = 0.2
 
-    with self.test_session():
+    with self.cached_session():
       predictions_tensor = constant_op.constant(
           prediction_logits, dtype=dtypes.float32)
       loss_for_positives, _ = losses.per_example_exp_loss(
@@ -84,7 +84,7 @@
     predictions = np.array(
         [[0.123], [23.2], [233], [52], [3]], dtype=np.float32)
 
-    with self.test_session():
+    with self.cached_session():
       loss_tensor, _ = losses.per_example_squared_loss(labels, weights,
                                                        predictions)
 
diff --git a/tensorflow/contrib/cmake/README.md b/tensorflow/contrib/cmake/README.md
index 789dab8..77242b3 100644
--- a/tensorflow/contrib/cmake/README.md
+++ b/tensorflow/contrib/cmake/README.md
@@ -17,7 +17,7 @@
 Current Status
 --------------
 
-CMake can be used to build TensorFlow on Windows. See the [getting started documentation](https://www.tensorflow.org/install/install_windows)
+CMake can be used to build TensorFlow on Windows. See the [getting started documentation](https://www.tensorflow.org/install/source_windows)
 for instructions on how to install a pre-built TensorFlow package on Windows.
 
 ### Current known limitations
diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt
index fb871ac..1c432b6 100644
--- a/tensorflow/contrib/cmake/python_modules.txt
+++ b/tensorflow/contrib/cmake/python_modules.txt
@@ -273,9 +273,6 @@
 tensorflow/contrib/libsvm/python
 tensorflow/contrib/libsvm/python/kernel_tests
 tensorflow/contrib/libsvm/python/ops
-tensorflow/contrib/linalg
-tensorflow/contrib/linalg/python
-tensorflow/contrib/linalg/python/ops
 tensorflow/contrib/linear_optimizer
 tensorflow/contrib/linear_optimizer/kernels
 tensorflow/contrib/linear_optimizer/kernels/g3doc
diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake
index 2c878c1..ed31351 100644
--- a/tensorflow/contrib/cmake/tf_tests.cmake
+++ b/tensorflow/contrib/cmake/tf_tests.cmake
@@ -183,7 +183,6 @@
     file(GLOB_RECURSE tf_test_src_py
       ${tf_test_src_py}
       "${tensorflow_source_dir}/tensorflow/contrib/legacy_seq2seq/*_test.py"
-      "${tensorflow_source_dir}/tensorflow/contrib/linalg/*_test.py"
       "${tensorflow_source_dir}/tensorflow/contrib/graph_editor/*_test.py"
       "${tensorflow_source_dir}/tensorflow/contrib/bayesflow/*_test.py"
       "${tensorflow_source_dir}/tensorflow/contrib/framework/*_test.py"
diff --git a/tensorflow/contrib/coder/python/ops/coder_ops_test.py b/tensorflow/contrib/coder/python/ops/coder_ops_test.py
index d5e14e7..f5431ca 100644
--- a/tensorflow/contrib/coder/python/ops/coder_ops_test.py
+++ b/tensorflow/contrib/coder/python/ops/coder_ops_test.py
@@ -45,7 +45,7 @@
     decoded = coder_ops.range_decode(
         encoded, array_ops.shape(data), cdf, precision=14)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertAllEqual(*sess.run((data, decoded)))
 
 
diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD
index d7583be..3b0e8f6 100644
--- a/tensorflow/contrib/compiler/BUILD
+++ b/tensorflow/contrib/compiler/BUILD
@@ -53,11 +53,14 @@
     srcs = ["xla.py"],
     srcs_version = "PY2AND3",
     deps = [
+        "//tensorflow/compiler/jit:xla_ops_py",
+        "//tensorflow/contrib/tpu:tpu_lib",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:platform",
         "//tensorflow/python:util",
+        "//tensorflow/python:variable_scope",
         "//tensorflow/python/estimator:model_fn",
     ],
 )
diff --git a/tensorflow/contrib/compiler/jit_test.py b/tensorflow/contrib/compiler/jit_test.py
index 42b3b9f..3e631b5 100644
--- a/tensorflow/contrib/compiler/jit_test.py
+++ b/tensorflow/contrib/compiler/jit_test.py
@@ -173,7 +173,7 @@
 class CompilationEnabledInGradientTest(test.TestCase):
 
   def testCompilationInGradient(self):
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant([[3.]])
       y_nc = math_ops.matmul(x, x, name="not_compiled")
       with jit.experimental_jit_scope():
diff --git a/tensorflow/contrib/compiler/xla.py b/tensorflow/contrib/compiler/xla.py
index 60f5af1..0aae695 100644
--- a/tensorflow/contrib/compiler/xla.py
+++ b/tensorflow/contrib/compiler/xla.py
@@ -12,18 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # =============================================================================
-"""xla provides experimental xla support API."""
+"""xla is an experimental library that provides XLA support APIs."""
 
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
+from tensorflow.compiler.jit.ops import xla_ops
+from tensorflow.contrib.tpu.python.tpu import tpu_function
 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import compat
 
@@ -51,6 +55,30 @@
 ])
 
 
+def compile(computation, inputs=None):  # pylint: disable=redefined-builtin
+  """Builds an operator that compiles and runs `computation` with XLA.
+
+  Args:
+    computation: A Python function that builds a computation to apply to the
+      input. If the function takes n inputs, 'inputs' should be a list of n
+      tensors.
+
+      `computation` may return a list of operations and tensors.  Tensors must
+      come before operations in the returned list.  The return value of
+      `compile` is a list of tensors corresponding to the tensors from the
+      output of `computation`.
+
+      All `Operation`s returned from `computation` will be executed when
+      evaluating any of the returned output tensors.
+    inputs: A list of input tensors or `None` (equivalent to an empty list).
+
+  Returns:
+    A list of output tensors.
+  """
+  # pylint: disable=protected-access
+  return _compile_internal(computation, inputs)
+
+
 class XLACompileContext(control_flow_ops.XLAControlFlowContext):
   """A `ControlFlowContext` for nodes inside an XLA computation cluster.
 
@@ -206,3 +234,122 @@
     if self.GetWhileContext():
       return self.GetWhileContext().back_prop
     return False
+
+
+def _compile_internal(computation, inputs=None):
+  """Builds graph operators that compiles and symbolically executes computation.
+
+  Args:
+    computation: A Python function that builds the computation to compile and
+      execute.
+    inputs: A list of input tensors or `None` (equivalent to `[]`). Its order
+      should match the ordering of the computation arguments.
+  Returns:
+    A list of output tensors from computation.
+  Raises:
+    ValueError: If any element of the computation outputs is neither an
+      operation nor a value that can be converted to a tensor.
+    TypeError: If `inputs` is not a list or tuple.
+  """
+  if inputs is None:
+    inputs = []
+
+  if not isinstance(inputs, collections.Sequence):
+    raise TypeError('inputs must be a list')
+
+  # Converts inputs to Tensors.
+  inputs = [ops.convert_to_tensor(x) for x in inputs]
+  input_arity = len(inputs)
+
+  arg_error = tpu_function.check_function_argument_count(
+      computation, input_arity, infeed_queue=None)
+  if arg_error is not None:
+    raise TypeError(
+        'Supplied computation cannot be called with the specified inputs. You '
+        'specified %d inputs: %s, but the computation needs %s' %
+        (input_arity, str([i.name for i in inputs]), arg_error))
+
+  cluster_name = ops.get_default_graph().unique_name('cluster')
+  pivot = control_flow_ops.no_op(name=cluster_name + '/pivot')
+  context = XLACompileContext(name=cluster_name, pivot=pivot)
+  try:
+    context.Enter()
+
+    # Add identity ops so even unused inputs are 'consumed' by the
+    # computation.
+    computation_inputs = [
+        array_ops.identity(x, name='input_{}'.format(i))
+        for i, x in enumerate(inputs)
+    ]
+
+    # Only resource variables work inside an XLA computation, so turn on
+    # resource variables for the computation.
+    vscope = variable_scope.get_variable_scope()
+    saved_use_resource = vscope.use_resource
+    vscope.set_use_resource(True)
+
+    outputs = computation(*computation_inputs)
+
+    # Restore variable scope after computation.
+    vscope.set_use_resource(saved_use_resource)
+
+    # If the computation returns `None`, make it an empty tuple.
+    if outputs is None:
+      outputs = tuple()
+    # If the computation only returned one value, make it a tuple.
+    if not isinstance(outputs, collections.Sequence):
+      outputs = (outputs,)
+
+    # Append a `no_op` here so that the return value of this function always
+    # contains at least one op that can trigger the XlaLaunch node.
+    outputs += (control_flow_ops.no_op(),)
+    try:
+      outputs = [
+          o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
+          for o in outputs
+      ]
+    except Exception as e:
+      raise ValueError(
+          'XLA computation function return values must all be either Operations'
+          ' or convertible to Tensors. Got error: "%s"' % str(e))
+
+    # Separates the returned Operations and Tensors.
+    output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
+    output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)]
+
+    if outputs != output_tensors + output_operations:
+      raise ValueError(
+          'XLA computation function must return zero or more Tensor values '
+          'followed by zero or more Operations.')
+    output_arity = len(output_tensors)
+
+    new_output_tensors = []
+    for t in output_tensors:
+      with ops.device(t.device if t.device else ''):
+        new_output_tensors.append(array_ops.identity(t))
+
+    output_tensors = new_output_tensors
+    context.ExitResult(output_tensors)
+  finally:
+    context.report_unsupported_operations()
+    context.Exit()
+
+  outputs = [
+      xla_ops.xla_cluster_output(output_tensors[i], name='output{}'.format(i))
+      for i in xrange(output_arity)
+  ]
+
+  with ops.control_dependencies(output_operations):
+    if output_arity == 0:
+      # When XLA computation returns only operations and no tensors, a NoOp
+      # dependent on the operations in outputs is returned. Otherwise final
+      # outputs would be empty and there would be no way to trigger the
+      # returned operations.
+      return control_flow_ops.no_op(name='output_0')
+    else:
+      # Wraps the outputs in identity operators that carry the control
+      # dependencies.
+      return [
+          array_ops.identity(outputs[i], name='output_%d' % i)
+          for i in xrange(output_arity)
+      ]
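
For orientation, here is a hedged sketch of how the new `compile` entry point above might be called from user code. It assumes TF 1.x graph mode with XLA available at runtime; the shapes and feed values are made up for illustration only.

```python
# Hedged usage sketch for the contrib xla.compile API defined above.
import numpy as np
import tensorflow as tf
from tensorflow.contrib.compiler import xla


def computation(a, b):
  # Tensors must come before Operations in the returned values.
  return tf.matmul(a, b)


a = tf.placeholder(tf.float32, [2, 2])
b = tf.placeholder(tf.float32, [2, 2])
(product,) = xla.compile(computation, inputs=[a, b])

with tf.Session() as sess:
  print(sess.run(product, feed_dict={
      a: np.eye(2, dtype=np.float32),
      b: np.arange(4, dtype=np.float32).reshape(2, 2)}))
```
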
diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py
index 5a66748..c59d368 100644
--- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py
+++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py
@@ -413,6 +413,31 @@
         self._testOneLSTMParamsSize(num_layers, num_units, input_size,
                                     direction)
 
+  @unittest.skipUnless(test.is_built_with_cuda(),
+                       "Test only applicable when running on GPUs")
+  def testLSTMParamsSizeShape(self):
+    with self.assertRaisesRegexp(
+        ValueError, "Shape must be rank 0 but is rank 1"):
+      model = _CreateModel(
+          cudnn_rnn_ops.CUDNN_LSTM,
+          constant_op.constant([4]), 200, 200,
+          direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION)
+      params_size = model.params_size()
+    with self.assertRaisesRegexp(
+        ValueError, "Shape must be rank 0 but is rank 1"):
+      model = _CreateModel(
+          cudnn_rnn_ops.CUDNN_LSTM,
+          4, constant_op.constant([200]), 200,
+          direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION)
+      params_size = model.params_size()
+    with self.assertRaisesRegexp(
+        ValueError, "Shape must be rank 0 but is rank 1"):
+      model = _CreateModel(
+          cudnn_rnn_ops.CUDNN_LSTM,
+          4, 200, constant_op.constant([200]),
+          direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION)
+      params_size = model.params_size()
+
 
 class CudnnRNNTestInference(TensorFlowTestCase):
 
diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
index fda1b9f..57793a8 100644
--- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
+++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
@@ -460,7 +460,7 @@
       grad, = gradients.gradients(
           math_ops.reduce_sum(accumulation), (original_input,))
     init_op = variables.global_variables_initializer()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       accumulation_eval, grad_eval = sess.run((accumulation, grad))
       self.assertAllEqual([28, 100, 100], accumulation_eval.shape)
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index baec238..c378b1c 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -62,6 +62,8 @@
 @@sloppy_interleave
 @@unbatch
 @@unique
+
+@@AUTOTUNE
 """
 
 from __future__ import absolute_import
@@ -91,6 +93,10 @@
 from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave
 from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook
 from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator
+
+# Optimization constant that can be used to enable auto-tuning.
+from tensorflow.contrib.data.python.ops.optimization import AUTOTUNE
+
 from tensorflow.contrib.data.python.ops.parsing_ops import parse_example_dataset
 from tensorflow.contrib.data.python.ops.prefetching_ops import copy_to_device
 from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device
@@ -113,6 +119,3 @@
 
 from tensorflow.python.util.all_util import remove_undocumented
 remove_undocumented(__name__)
-
-# A constant that can be used to enable auto-tuning.
-AUTOTUNE = -1
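
`AUTOTUNE` now lives next to the other optimization helpers and is re-exported from `tf.contrib.data`. A minimal sketch of the intended use, assuming the TF 1.11-era contrib API shown in this diff (the pipeline itself is illustrative):

```python
# Minimal sketch: let the tf.data runtime pick the parallelism level.
import tensorflow as tf
from tensorflow.contrib.data import AUTOTUNE

dataset = tf.data.Dataset.range(1000)
dataset = dataset.map(lambda x: x * x, num_parallel_calls=AUTOTUNE)
dataset = dataset.prefetch(AUTOTUNE)
```
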
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
index 8e368bf..e2508de 100644
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
@@ -742,7 +742,7 @@
     iterator = result.make_initializable_iterator()
     init_op = iterator.initializer
     get_next = iterator.get_next()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for _ in range(5):
         sess.run(get_next)
@@ -813,7 +813,7 @@
         .make_initializable_iterator())
     init_op = iterator.initializer
     get_next = iterator.get_next()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       with self.assertRaises(errors.InvalidArgumentError):
         sess.run(get_next)
@@ -837,7 +837,7 @@
     iterator = result.make_initializable_iterator()
     init_op = iterator.initializer
     get_next = iterator.get_next()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for _ in range(5):
         sess.run(get_next)
@@ -879,7 +879,7 @@
     iterator = result.make_initializable_iterator()
     init_op = iterator.initializer
     get_next = iterator.get_next()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for _ in range(5):
         sess.run(get_next)
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
index 83b7237..25aea03 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
@@ -116,7 +116,7 @@
     elems2 = array_ops.placeholder(dtypes.int32)
     result = map_defun.map_defun(fn, [elems1, elems2],
                                  [dtypes.int32, dtypes.int32], [(), ()])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with self.assertRaisesWithPredicateMatch(
           errors.InvalidArgumentError,
           "All inputs must have the same dimension 0."):
@@ -225,7 +225,7 @@
     c = constant_op.constant([1, 2, 3, 4, 5])
     map_defun_op = map_defun.map_defun(simple_fn, [c], [dtypes.int32], [()])[0]
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       thread = self.checkedThread(
           self._assert_op_cancelled, args=(sess, map_defun_op))
       thread.start()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
index 7e9ea68..b3187bf 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
@@ -74,6 +74,23 @@
 )
 
 py_test(
+    name = "map_parallelization_test",
+    size = "small",
+    srcs = ["map_parallelization_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/contrib/data/python/ops:optimization",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:errors",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
+
+py_test(
     name = "model_dataset_op_test",
     size = "medium",
     srcs = ["model_dataset_op_test.py"],
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
index bd7b50b..d10da80 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
@@ -31,7 +31,7 @@
     iterator = dataset.make_one_shot_iterator()
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertEqual(0, sess.run(get_next))
 
   def testAssertNextInvalid(self):
@@ -40,7 +40,7 @@
     iterator = dataset.make_one_shot_iterator()
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with self.assertRaisesRegexp(
           errors.InvalidArgumentError,
           "Asserted Whoops transformation at offset 0 but encountered "
@@ -53,7 +53,7 @@
     iterator = dataset.make_one_shot_iterator()
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with self.assertRaisesRegexp(
           errors.InvalidArgumentError,
           "Asserted next 2 transformations but encountered only 1."):
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
index dde1159..e75edf6 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
@@ -200,7 +200,7 @@
         optimization.optimize(["filter_fusion"]))
     iterator = dataset.make_one_shot_iterator()
     get_next = iterator.get_next()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for x in range(5):
         r = map_function(x)
         filtered = False
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py
new file mode 100644
index 0000000..dd547db
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py
@@ -0,0 +1,84 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the MapParallelization optimization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.platform import test
+
+
+class MapParallelizationTest(test.TestCase, parameterized.TestCase):
+
+  @staticmethod
+  def map_functions():
+    identity = lambda x: x
+    increment = lambda x: x + 1
+
+    def assert_greater(x):
+      assert_op = control_flow_ops.Assert(math_ops.greater(x, -1), [x])
+      with ops.control_dependencies([assert_op]):
+        return x
+
+    def random(_):
+      return random_ops.random_uniform([],
+                                       minval=0,
+                                       maxval=10,
+                                       dtype=dtypes.int64,
+                                       seed=42)
+
+    def assert_with_random(x):
+      x = assert_greater(x)
+      return random(x)
+
+    return (("Identity", identity, True), ("Increment", increment, True),
+            ("AssertGreater", assert_greater, True), ("Random", random, False),
+            ("AssertWithRandom", assert_with_random, False))
+
+  @parameterized.named_parameters(*map_functions.__func__())
+  def testMapParallelization(self, function, should_optimize):
+    next_nodes = ["ParallelMap"] if should_optimize else ["Map"]
+    dataset = dataset_ops.Dataset.range(5).apply(
+        optimization.assert_next(next_nodes)).map(function).apply(
+            optimization.optimize(["map_parallelization"]))
+    iterator = dataset.make_one_shot_iterator()
+    get_next = iterator.get_next()
+
+    with self.test_session() as sess:
+      for x in range(5):
+        result = sess.run(get_next)
+        # No need to run the pipeline if it was not optimized.  Also the results
+        # might be hard to check because of random.
+        if not should_optimize:
+          return
+        r = function(x)
+        self.assertAllEqual(r, result)
+
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
+
+if __name__ == "__main__":
+  test.main()
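
The new test exercises the `map_parallelization` rewrite, which turns a stateless `Map` into `ParallelMap`. A hedged sketch of opting into the rewrite on an ordinary pipeline, mirroring the test above (TF 1.11-era contrib API; the map function is illustrative):

```python
# Hedged sketch mirroring the test above: request the map_parallelization
# rewrite explicitly via the contrib optimize() transformation.
from tensorflow.contrib.data.python.ops import optimization
from tensorflow.python.data.ops import dataset_ops

dataset = dataset_ops.Dataset.range(100).map(lambda x: x + 1).apply(
    optimization.optimize(["map_parallelization"]))
```
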
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
index e2c9bc8..5b493f4 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
@@ -173,16 +173,6 @@
     self.report_benchmark(iters=num_iters, wall_time=median_time, name=name)
     return median_time
 
-  def benchmark_CheapFns(self):
-
-    input_sizes = [(10, 10, 3), (10, 100, 300)]
-    batch_size = 1000
-    for input_size in input_sizes:
-      input_dataset = dataset_ops.Dataset.from_tensor_slices(
-          (np.random.rand(*input_size), np.random.rand(*input_size))).repeat()
-      for map_fn, str_id in self._get_known_cheap_fns():
-        self._compare(input_dataset, map_fn, batch_size, input_size, str_id)
-
   def _compare(self, input_dataset, map_fn, batch_size, input_size, str_id):
     num_elems = np.prod(input_size)
     name_template = "{}__batch_size_{}_input_size_{}_{}"
@@ -205,14 +195,28 @@
           "Speedup: {}\n".format(batch_size, input_size, str_id,
                                  (unoptimized_time / optimized_time)))
 
-  def _get_known_cheap_fns(self):
-    return [
-        (lambda *args: [array_ops.identity(x) for x in args], "identity"),
-        (lambda *args: [x + 1 for x in args], "add_const"),
-        (lambda *args: args[0], "select"),
-        (lambda *args: [math_ops.cast(x, dtypes.float64) for x in args],
-         "cast"),
-    ]
+  # Known cheap functions
+  def benchmarkIdentity(self):
+    self._benchmark_helper(lambda *args: [array_ops.identity(x) for x in args],
+                           "identity")
+
+  def benchmarkAddConst(self):
+    self._benchmark_helper(lambda *args: [x + 1 for x in args], "add_const")
+
+  def benchmarkSelect(self):
+    self._benchmark_helper(lambda *args: args[0], "select")
+
+  def benchmarkCast(self):
+    self._benchmark_helper(
+        lambda *args: [math_ops.cast(x, dtypes.float64) for x in args], "cast")
+
+  def _benchmark_helper(self, map_fn, str_id):
+    input_sizes = [(10, 10, 3), (10, 100, 300)]
+    batch_size = 1000
+    for input_size in input_sizes:
+      input_dataset = dataset_ops.Dataset.from_tensor_slices(
+          (np.random.rand(*input_size), np.random.rand(*input_size))).repeat()
+      self._compare(input_dataset, map_fn, batch_size, input_size, str_id)
 
 
 if __name__ == "__main__":
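
The refactor above replaces one looping benchmark with one `benchmark*` method per cheap function. This matters because the TF benchmark runner discovers methods by that name prefix, so each function now gets its own reported entry. A schematic sketch of the pattern, with made-up names and no real timing:

```python
# Schematic sketch of the per-method benchmark pattern used above:
# tf.test.Benchmark subclasses report one entry per benchmark* method.
import tensorflow as tf


class CheapFnBenchmark(tf.test.Benchmark):

  def _benchmark_helper(self, map_fn, str_id):
    # Real timing code would go here; report_benchmark records the entry.
    self.report_benchmark(iters=1, wall_time=0.0, name=str_id)

  def benchmarkIdentity(self):
    self._benchmark_helper(lambda *args: args, "identity")

  def benchmarkAddConst(self):
    self._benchmark_helper(lambda *args: [x + 1 for x in args], "add_const")
```
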
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py
index 0a87d3e..3b62a7e 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py
@@ -40,7 +40,7 @@
     get_next = iterator.get_next()
 
     deltas = []
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for _ in range(5):
         sess.run(get_next.op)
       for _ in range(100):
@@ -58,12 +58,13 @@
     dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
                                                 np.random.rand(4 * k,
                                                                1))).repeat()
-    dataset = dataset.map(math_ops.matmul, num_parallel_calls=56)
+    dataset = dataset.map(
+        math_ops.matmul, num_parallel_calls=optimization.AUTOTUNE)
     iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
     get_next = iterator.get_next()
 
     deltas = []
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for _ in range(5):
         sess.run(get_next.op)
       for _ in range(1000):
@@ -84,12 +85,14 @@
                                                                1))).repeat()
     dataset = dataset.apply(
         batching.map_and_batch(
-            math_ops.matmul, num_parallel_calls=28, batch_size=batch_size))
+            math_ops.matmul,
+            num_parallel_calls=optimization.AUTOTUNE,
+            batch_size=batch_size))
     iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
     get_next = iterator.get_next()
 
     deltas = []
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for _ in range(5):
         sess.run(get_next.op)
       for _ in range(10):
@@ -109,12 +112,14 @@
                                                                1))).repeat()
     dataset = dataset.map(math_ops.matmul)
     dataset = dataset_ops.Dataset.range(1).repeat().interleave(
-        lambda _: dataset, cycle_length=56, num_parallel_calls=56)
+        lambda _: dataset,
+        cycle_length=10,
+        num_parallel_calls=optimization.AUTOTUNE)
     iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
     get_next = iterator.get_next()
 
     deltas = []
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for _ in range(5):
         sess.run(get_next.op)
       for _ in range(1000):
@@ -146,20 +151,20 @@
       x, y = c
       return a, b, math_ops.matmul(x, y)
 
-    dataset = dataset.map(f1, num_parallel_calls=32)
+    dataset = dataset.map(f1, num_parallel_calls=optimization.AUTOTUNE)
     dataset = dataset_ops.Dataset.range(1).repeat().interleave(
         lambda _: dataset, cycle_length=2)
 
-    dataset = dataset.map(f2, num_parallel_calls=16)
+    dataset = dataset.map(f2, num_parallel_calls=optimization.AUTOTUNE)
     dataset = dataset_ops.Dataset.range(1).repeat().interleave(
         lambda _: dataset, cycle_length=2)
 
-    dataset = dataset.map(f3, num_parallel_calls=10)
+    dataset = dataset.map(f3, num_parallel_calls=optimization.AUTOTUNE)
     iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
     get_next = iterator.get_next()
 
     deltas = []
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for _ in range(5):
         sess.run(get_next)
       for _ in range(100):
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
index 909da5a..a3fb824 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
@@ -38,7 +38,7 @@
     iterator = dataset.make_one_shot_iterator()
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
@@ -51,7 +51,7 @@
     iterator = dataset.make_one_shot_iterator()
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
@@ -64,7 +64,7 @@
     iterator = dataset.make_one_shot_iterator()
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
@@ -76,7 +76,7 @@
     iterator = dataset.make_one_shot_iterator()
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(get_next)
 
   def testOptimizationLargeInputFromTensor(self):
@@ -87,7 +87,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op, {input_t: np.ones([512, 1024, 1025], np.int32)})
       sess.run(get_next)
 
@@ -99,7 +99,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op, {input_t: np.ones([1, 512, 1024, 1025], np.int32)})
       sess.run(get_next)
 
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
index e25570c..be8ae5e 100644
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
@@ -25,6 +25,7 @@
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
 
 
@@ -40,7 +41,7 @@
     next_element = iterator.get_next()
     summary_t = stats_aggregator.get_summary()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(iterator.initializer)
       expected_sum = 0.0
       for i in range(100):
@@ -65,7 +66,7 @@
     next_element = iterator.get_next()
     summary_t = stats_aggregator.get_summary()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(iterator.initializer)
       for i in range(100):
         self.assertEqual(i, sess.run(next_element))
@@ -84,7 +85,7 @@
     next_element = iterator.get_next()
     summary_t = stats_aggregator.get_summary()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(iterator.initializer)
       for i in range(100):
         self.assertAllEqual(
@@ -92,6 +93,8 @@
         summary_str = sess.run(summary_t)
         self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
                                     float(i + 1))
+        self._assertSummaryContains(summary_str, "Prefetch::buffer_capacity")
+        self._assertSummaryContains(summary_str, "Prefetch::buffer_size")
         self._assertSummaryHasRange(summary_str, "Prefetch::buffer_utilization",
                                     0, 1)
       with self.assertRaises(errors.OutOfRangeError):
@@ -100,6 +103,53 @@
       self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
                                   100)
 
+  def testPrefetchBufferScalars(self):
+    stats_aggregator = stats_ops.StatsAggregator()
+    dataset = dataset_ops.Dataset.range(10).map(
+        lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch(
+            0).apply(stats_ops.set_stats_aggregator(stats_aggregator))
+    iterator = dataset.make_initializable_iterator()
+    next_element = iterator.get_next()
+    summary_t = stats_aggregator.get_summary()
+
+    with self.cached_session() as sess:
+      sess.run(iterator.initializer)
+      for i in range(10):
+        self.assertAllEqual(
+            np.array([i] * i, dtype=np.int64), sess.run(next_element))
+        summary_str = sess.run(summary_t)
+        self._assertSummaryHasScalarValue(summary_str,
+                                          "Prefetch::buffer_capacity", 0)
+        self._assertSummaryHasScalarValue(summary_str, "Prefetch::buffer_size",
+                                          0)
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(next_element)
+
+  def testFilteredElementsStats(self):
+    stats_aggregator = stats_ops.StatsAggregator()
+    dataset = dataset_ops.Dataset.range(101).filter(
+        lambda x: math_ops.equal(math_ops.mod(x, 3), 0)).apply(
+            stats_ops.set_stats_aggregator(stats_aggregator))
+    iterator = dataset.make_initializable_iterator()
+    next_element = iterator.get_next()
+    summary_t = stats_aggregator.get_summary()
+
+    with self.test_session() as sess:
+      sess.run(iterator.initializer)
+      for i in range(34):
+        self.assertEqual(i * 3, sess.run(next_element))
+        if i != 0:
+          self._assertSummaryHasScalarValue(
+              sess.run(summary_t), "Filter::dropped_elements", float(i * 2))
+        self._assertSummaryHasScalarValue(
+            sess.run(summary_t), "Filter::filtered_elements", float(i + 1))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(next_element)
+      self._assertSummaryHasScalarValue(
+          sess.run(summary_t), "Filter::dropped_elements", 67.0)
+      self._assertSummaryHasScalarValue(
+          sess.run(summary_t), "Filter::filtered_elements", 34.0)
+
   def testReinitialize(self):
     stats_aggregator = stats_ops.StatsAggregator()
     dataset = dataset_ops.Dataset.range(100).apply(
@@ -109,7 +159,7 @@
     next_element = iterator.get_next()
     summary_t = stats_aggregator.get_summary()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for j in range(5):
         sess.run(iterator.initializer)
         for i in range(100):
@@ -127,7 +177,7 @@
     iterator = dataset.make_initializable_iterator()
     next_element = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(iterator.initializer)
       for i in range(100):
         self.assertEqual(i, sess.run(next_element))
@@ -144,7 +194,7 @@
     next_element = iterator.get_next()
     summary_t = stats_aggregator.get_summary()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(iterator.initializer)
       for i in range(100):
         self.assertEqual(i, sess.run(next_element))
@@ -168,7 +218,7 @@
     next_element = iterator.get_next()
     summary_t = stats_aggregator.get_summary()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(iterator.initializer)
       for i in range(100):
         self.assertEqual(i, sess.run(next_element))
@@ -188,7 +238,7 @@
     next_element = iterator_0.get_next() + iterator_1.get_next()
     summary_t = stats_aggregator.get_summary()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run([iterator_0.initializer, iterator_1.initializer])
       for i in range(100):
         self.assertEqual(i * 2, sess.run(next_element))
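
The new tests above read `Filter::*` and `Prefetch::*` statistics out of the aggregated summary. A hedged sketch of the wiring they rely on (TF 1.11-era contrib API; the range/filter pipeline is chosen only for illustration):

```python
# Hedged sketch of attaching a StatsAggregator so Filter statistics are collected.
from tensorflow.contrib.data.python.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import math_ops

aggregator = stats_ops.StatsAggregator()
dataset = dataset_ops.Dataset.range(101).filter(
    lambda x: math_ops.equal(math_ops.mod(x, 3), 0)).apply(
        stats_ops.set_stats_aggregator(aggregator))
summary_t = aggregator.get_summary()  # serialized tf.Summary with Filter::* tags
```
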
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
index 2f5a444..b1b4c23 100644
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
@@ -25,6 +25,14 @@
 class StatsDatasetTestBase(test.TestCase):
   """Base class for testing statistics gathered in `StatsAggregator`."""
 
+  def _assertSummaryContains(self, summary_str, tag):
+    summary_proto = summary_pb2.Summary()
+    summary_proto.ParseFromString(summary_str)
+    for value in summary_proto.value:
+      if tag == value.tag:
+        return
+    self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
+
   def _assertSummaryHasCount(self, summary_str, tag, expected_value):
     summary_proto = summary_pb2.Summary()
     summary_proto.ParseFromString(summary_str)
@@ -52,3 +60,12 @@
         self.assertEqual(expected_value, value.histo.sum)
         return
     self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
+
+  def _assertSummaryHasScalarValue(self, summary_str, tag, expected_value):
+    summary_proto = summary_pb2.Summary()
+    summary_proto.ParseFromString(summary_str)
+    for value in summary_proto.value:
+      if tag == value.tag:
+        self.assertEqual(expected_value, value.simple_value)
+        return
+    self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
diff --git a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
index 6eaa0b19..8b7b3ac 100644
--- a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
@@ -89,13 +89,14 @@
       return dataset_ops.Dataset.zip(
           tuple([fn(*arg) if isinstance(arg, tuple) else arg for arg in args]))
 
-    dataset = self._structuredDataset(structure, shape, dtype).apply(
+    dataset = self._structuredDataset(structure, shape, dtype).repeat(5).apply(
         grouping.window_dataset(5)).flat_map(fn)
     get_next = dataset.make_one_shot_iterator().get_next()
     with self.cached_session() as sess:
       expected = sess.run(self._structuredElement(structure, shape, dtype))
-      actual = sess.run(get_next)
-      self._assertEqual(expected, actual)
+      for _ in range(5):
+        actual = sess.run(get_next)
+        self._assertEqual(expected, actual)
 
   @parameterized.named_parameters(
       ("1", None, np.int32([]), dtypes.bool),
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index 4b45cc7..a14781c 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -80,6 +80,7 @@
         ":batching",
         ":gen_dataset_ops",
         ":interleave_ops",
+        ":optimization",
         ":parsing_ops",
         ":shuffle_ops",
         "//tensorflow/python:constant_op",
diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py
index 099e10d..020167e 100644
--- a/tensorflow/contrib/data/python/ops/grouping.py
+++ b/tensorflow/contrib/data/python/ops/grouping.py
@@ -255,6 +255,7 @@
   return _apply_fn
 
 
+# TODO(b/115382007) Remove this once canned reducers move to core.
 def window_dataset(window_size):
   """A transformation that creates window datasets from the input dataset.
 
@@ -271,7 +272,12 @@
   """
 
   def _apply_fn(dataset):
-    return _WindowDataset(dataset, window_size)
+    return dataset_ops.WindowDataset(
+        dataset,
+        size=window_size,
+        shift=window_size,
+        stride=1,
+        drop_remainder=False)
 
   return _apply_fn
 
@@ -556,46 +562,3 @@
   @property
   def output_types(self):
     return self._output_types
-
-
-class _WindowDataset(dataset_ops.Dataset):
-  """A dataset that creates window datasets from the input elements."""
-
-  def __init__(self, input_dataset, window_size):
-    """See `window_dataset()` for more details."""
-    super(_WindowDataset, self).__init__()
-    self._input_dataset = input_dataset
-    self._window_size = ops.convert_to_tensor(
-        window_size, dtype=dtypes.int64, name="window_size")
-    self._output_classes = nest.pack_sequence_as(
-        input_dataset.output_classes,
-        [
-            dataset_ops._NestedDatasetComponent(  # pylint: disable=protected-access
-                output_classes=output_class,
-                output_shapes=output_shape,
-                output_types=output_type)
-            for output_class, output_shape, output_type in zip(
-                nest.flatten(input_dataset.output_classes),
-                nest.flatten(input_dataset.output_shapes),
-                nest.flatten(input_dataset.output_types))
-        ])
-    self._output_shapes = self._output_classes
-    self._output_types = self._output_classes
-
-  def _as_variant_tensor(self):
-    return gen_dataset_ops.window_dataset(
-        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
-        self._window_size,
-        **dataset_ops.flat_structure(self))
-
-  @property
-  def output_classes(self):
-    return self._output_classes
-
-  @property
-  def output_shapes(self):
-    return self._output_shapes
-
-  @property
-  def output_types(self):
-    return self._output_types
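
With the private `_WindowDataset` removed, the contrib transformation above is a thin wrapper over the core `WindowDataset` with `shift == size` (non-overlapping windows, remainder kept). A hedged equivalence sketch based only on the call shown in the diff:

```python
# Hedged sketch of the delegation above: both expressions build the same op.
from tensorflow.contrib.data.python.ops import grouping
from tensorflow.python.data.ops import dataset_ops

via_contrib = dataset_ops.Dataset.range(7).apply(grouping.window_dataset(3))
via_core = dataset_ops.WindowDataset(
    dataset_ops.Dataset.range(7), size=3, shift=3, stride=1,
    drop_remainder=False)
```
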
diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py
index 4114b62..7384045 100644
--- a/tensorflow/contrib/data/python/ops/optimization.py
+++ b/tensorflow/contrib/data/python/ops/optimization.py
@@ -24,6 +24,9 @@
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import gen_dataset_ops
 
+# A constant that can be used to enable auto-tuning.
+AUTOTUNE = -1
+
 
 # TODO(jsimsa): Support RE matching for both individual transformation (e.g. to
 # account for indexing) and transformation sequence.
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index 4c46678..785b395 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -25,6 +25,7 @@
 from tensorflow.contrib.data.python.ops import batching
 from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops
 from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.contrib.data.python.ops import optimization
 from tensorflow.contrib.data.python.ops import parsing_ops
 from tensorflow.contrib.data.python.ops import shuffle_ops
 from tensorflow.python.data.ops import dataset_ops
@@ -214,18 +215,17 @@
   return dataset
 
 
-def make_tf_record_dataset(
-    file_pattern,
-    batch_size,
-    parser_fn=None,
-    num_epochs=None,
-    shuffle=True,
-    shuffle_buffer_size=None,
-    shuffle_seed=None,
-    prefetch_buffer_size=None,
-    num_parallel_reads=None,
-    num_parallel_parser_calls=None,
-    drop_final_batch=False):
+def make_tf_record_dataset(file_pattern,
+                           batch_size,
+                           parser_fn=None,
+                           num_epochs=None,
+                           shuffle=True,
+                           shuffle_buffer_size=None,
+                           shuffle_seed=None,
+                           prefetch_buffer_size=optimization.AUTOTUNE,
+                           num_parallel_reads=None,
+                           num_parallel_parser_calls=None,
+                           drop_final_batch=False):
   """Reads and optionally parses TFRecord files into a dataset.
 
   Provides common functionality such as batching, optional parsing, shuffling,
@@ -300,8 +300,6 @@
         parser_fn, batch_size, num_parallel_calls=num_parallel_parser_calls,
         drop_remainder=drop_final_batch))
 
-  if prefetch_buffer_size is None:
-    prefetch_buffer_size = -1  # tf.config.data.AUTOTUNE
   if prefetch_buffer_size == 0:
     return dataset
   else:
@@ -323,7 +321,7 @@
     shuffle=True,
     shuffle_buffer_size=10000,
     shuffle_seed=None,
-    prefetch_buffer_size=1,
+    prefetch_buffer_size=optimization.AUTOTUNE,
     num_parallel_reads=1,
     sloppy=False,
     num_rows_for_inference=100,
@@ -386,9 +384,10 @@
     shuffle_buffer_size: Buffer size to use for shuffling. A large buffer size
       ensures better shuffling, but increases memory usage and startup time.
     shuffle_seed: Randomization seed to use for shuffling.
-    prefetch_buffer_size: An int specifying the number of feature batches to
-      prefetch for performance improvement. Recommended value is the number of
-      batches consumed per training step.
+    prefetch_buffer_size: An int specifying the number of feature
+      batches to prefetch for performance improvement. Recommended value is the
+      number of batches consumed per training step. Defaults to auto-tune.
+
     num_parallel_reads: Number of threads used to read CSV records from files.
       If >1, the results will be interleaved.
     sloppy: If `True`, reading performance will be improved at
@@ -666,7 +665,7 @@
                                   shuffle=True,
                                   shuffle_buffer_size=10000,
                                   shuffle_seed=None,
-                                  prefetch_buffer_size=1,
+                                  prefetch_buffer_size=optimization.AUTOTUNE,
                                   reader_num_threads=1,
                                   parser_num_threads=2,
                                   sloppy_ordering=False,
@@ -739,7 +738,7 @@
     shuffle_seed: Randomization seed to use for shuffling.
     prefetch_buffer_size: Number of feature batches to prefetch in order to
       improve performance. Recommended value is the number of batches consumed
-      per training step (default is 1).
+      per training step. Defaults to auto-tune.
     reader_num_threads: Number of threads used to read `Example` records. If >1,
       the results will be interleaved.
     parser_num_threads: Number of threads to use for parsing `Example` tensors
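
With the default changed above, callers that leave `prefetch_buffer_size` unset now get auto-tuned prefetching instead of the old fixed value or `None` sentinel. A hedged sketch (contrib API; the file pattern is made up and parsing is omitted for brevity):

```python
# Hedged sketch: omitting prefetch_buffer_size now auto-tunes the prefetch buffer.
from tensorflow.contrib.data.python.ops import readers

dataset = readers.make_tf_record_dataset(
    file_pattern="train-*.tfrecord",  # made-up pattern
    batch_size=32)  # prefetch_buffer_size defaults to optimization.AUTOTUNE
```
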
diff --git a/tensorflow/contrib/data/python/ops/sliding.py b/tensorflow/contrib/data/python/ops/sliding.py
index 8025dcd..b0d6a16 100644
--- a/tensorflow/contrib/data/python/ops/sliding.py
+++ b/tensorflow/contrib/data/python/ops/sliding.py
@@ -67,6 +67,10 @@
 
 @deprecation.deprecated_args(
     None, "stride is deprecated, use window_shift instead", "stride")
+@deprecation.deprecated(
+    None, "Use `tf.data.Dataset.window(size=window_size, shift=window_shift, "
+    "stride=window_stride).flat_map(lambda x: x.batch(window.size))` "
+    "instead.")
 def sliding_window_batch(window_size,
                          stride=None,
                          window_shift=None,
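
The deprecation message above points at the core replacement. A minimal sketch of that replacement for concreteness, assuming `tf.data.Dataset.window` is available in this TF version (window parameters chosen only for illustration):

```python
# Minimal sketch of the replacement suggested by the deprecation message above.
import tensorflow as tf

window_size, window_shift, window_stride = 4, 1, 1
dataset = tf.data.Dataset.range(10).window(
    size=window_size, shift=window_shift, stride=window_stride)
# Each element is itself a dataset; batch it to recover the old batched window.
dataset = dataset.flat_map(lambda window: window.batch(window_size))
```
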
diff --git a/tensorflow/contrib/deprecated/summaries_test.py b/tensorflow/contrib/deprecated/summaries_test.py
index 6acf2a6..4038224 100644
--- a/tensorflow/contrib/deprecated/summaries_test.py
+++ b/tensorflow/contrib/deprecated/summaries_test.py
@@ -27,31 +27,31 @@
 class DeprecatedSummariesTest(test.TestCase):
 
   def testScalarSummary(self):
-    with self.test_session():
+    with self.cached_session():
       c = constant_op.constant(3)
       s = logging_ops.scalar_summary('tag', c)
       self.assertEqual(s.op.type, u'ScalarSummary')
 
   def testHistogramSummary(self):
-    with self.test_session():
+    with self.cached_session():
       c = constant_op.constant(3)
       s = logging_ops.histogram_summary('tag', c)
       self.assertEqual(s.op.type, u'HistogramSummary')
 
   def testImageSummary(self):
-    with self.test_session():
+    with self.cached_session():
       i = array_ops.ones((5, 4, 4, 3))
       s = logging_ops.image_summary('tag', i)
       self.assertEqual(s.op.type, u'ImageSummary')
 
   def testAudioSummary(self):
-    with self.test_session():
+    with self.cached_session():
       c = constant_op.constant(3.0)
       s = logging_ops.audio_summary('tag', c, sample_rate=8000)
       self.assertEqual(s.op.type, u'AudioSummaryV2')
 
   def testMergeSummary(self):
-    with self.test_session():
+    with self.cached_session():
       c = constant_op.constant(3)
       a = logging_ops.scalar_summary('a', c)
       b = logging_ops.scalar_summary('b', c)
diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py
index 350f81f..823fe6a 100644
--- a/tensorflow/contrib/distribute/__init__.py
+++ b/tensorflow/contrib/distribute/__init__.py
@@ -12,7 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Prototype of a distributed computation library for TF."""
+"""A distributed computation library for TF.
+
+See [tensorflow/contrib/distribute/README.md](
+https://www.tensorflow.org/code/tensorflow/contrib/distribute/README.md)
+for an overview and examples.
+"""
 
 from __future__ import absolute_import
 from __future__ import division
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index 87f76ea..48a7593 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -472,11 +472,8 @@
         "//tensorflow/python:summary",
     ],
     tags = [
-        "manual",
         "multi_and_single_gpu",
         "no_pip",
-        "nogpu",
-        "notap",
     ],
 )
 
@@ -485,7 +482,6 @@
     srcs = ["single_loss_example.py"],
     deps = [
         ":step_fn",
-        "//tensorflow/contrib/data/python/ops:batching",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:constant_op",
         "//tensorflow/python:layers",
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
index 4903714..a3e1b96 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
@@ -114,7 +114,7 @@
         self.assertEqual([v.numpy() for v in left._index.values()],
                          list(right._index.values()))
       else:
-        with self.test_session() as sess:
+        with self.cached_session() as sess:
           self.assertEqual(
               sess.run(list(left._index.values())), list(right._index.values()))
 
diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils.py b/tensorflow/contrib/distribute/python/cross_tower_utils.py
index 24cb08f..9fc1b88 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_utils.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_utils.py
@@ -221,9 +221,12 @@
   return small_grads, large_grads
 
 
-# threading.Lock() cannot be pickled and therefore cannot be a field of
-# CollectiveKeys.
+# threading.Lock() and threading.local() cannot be pickled and therefore cannot
+# be fields of CollectiveKeys. Right now _thread_local does not need to be an
+# instance member of CollectiveKeys because we always create a new thread for
+# each tower.
 _lock = threading.Lock()
+_thread_local = threading.local()
 
 
 # TODO(yuefengz): use random key starts to avoid reusing keys?
@@ -266,14 +269,12 @@
     # For instance keys without ids
     self._instance_key_start = instance_key_start
 
-    self._thread_local = threading.local()
-
   def _get_thread_local_object(self):
     # We make instance key without key ids thread local so that it will work
     # with MirroredStrategy and distribute coordinator.
-    if not hasattr(self._thread_local, 'instance_key'):
-      self._thread_local.instance_key = self._instance_key_start
-    return self._thread_local
+    if not hasattr(_thread_local, 'instance_key'):
+      _thread_local.instance_key = self._instance_key_start
+    return _thread_local
 
   def get_group_key(self, devices):
     """Returns a group key for the set of devices.
diff --git a/tensorflow/contrib/distribute/python/estimator_training_test.py b/tensorflow/contrib/distribute/python/estimator_training_test.py
index 5348512..157618f 100644
--- a/tensorflow/contrib/distribute/python/estimator_training_test.py
+++ b/tensorflow/contrib/distribute/python/estimator_training_test.py
@@ -26,21 +26,12 @@
 import threading
 from absl.testing import parameterized
 import numpy as np
-import six
 
-_portpicker_import_error = None
-try:
-  import portpicker  # pylint: disable=g-import-not-at-top
-except ImportError as _error:  # pylint: disable=invalid-name
-  _portpicker_import_error = _error
-  portpicker = None
-
-# pylint: disable=g-import-not-at-top
 from tensorflow.contrib.distribute.python import combinations
 from tensorflow.contrib.distribute.python import mirrored_strategy
+from tensorflow.contrib.distribute.python import multi_worker_test_base
 from tensorflow.contrib.distribute.python import parameter_server_strategy
 from tensorflow.contrib.optimizer_v2 import adagrad
-from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import distribute_coordinator as dc
 from tensorflow.python.distribute import estimator_training as dc_training
@@ -57,7 +48,6 @@
 from tensorflow.python.platform import test
 from tensorflow.python.summary import summary_iterator
 from tensorflow.python.summary.writer import writer_cache
-from tensorflow.python.training import server_lib
 
 BATCH_SIZE = 10
 LABEL_DIMENSION = 2
@@ -73,130 +63,38 @@
 WORKER = dc._TaskType.WORKER
 PS = dc._TaskType.PS
 
-original_run_distribute_coordinator = dc.run_distribute_coordinator
+original_run_std_server = dc._run_std_server
 
 
-# TODO(yuefengz): merge this method back to test_util.
-def _create_local_cluster(num_workers,
-                          num_ps,
-                          has_eval=False,
-                          protocol="grpc",
-                          worker_config=None,
-                          ps_config=None):
-  if _portpicker_import_error:
-    raise _portpicker_import_error  # pylint: disable=raising-bad-type
-  worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
-  ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
+class MockOsEnv(dict):
 
-  cluster_dict = {
-      "worker": ["localhost:%s" % port for port in worker_ports],
-      "ps": ["localhost:%s" % port for port in ps_ports]
-  }
-  if has_eval:
-    cluster_dict["evaluator"] = ["localhost:%s" % portpicker.pick_unused_port()]
+  def __init__(self, *args):
+    self._thread_local = threading.local()
+    super(MockOsEnv, self).__init__(*args)
 
-  cs = server_lib.ClusterSpec(cluster_dict)
+  def get(self, key, default):
+    if not hasattr(self._thread_local, "dict"):
+      self._thread_local.dict = dict()
+    if key == "TF_CONFIG":
+      return dict.get(self._thread_local.dict, key, default)
+    else:
+      return dict.get(self, key, default)
 
-  workers = [
-      server_lib.Server(
-          cs,
-          job_name="worker",
-          protocol=protocol,
-          task_index=ix,
-          config=worker_config,
-          start=True) for ix in range(num_workers)
-  ]
-  ps_servers = [
-      server_lib.Server(
-          cs,
-          job_name="ps",
-          protocol=protocol,
-          task_index=ix,
-          config=ps_config,
-          start=True) for ix in range(num_ps)
-  ]
-  if has_eval:
-    evals = [
-        server_lib.Server(
-            cs,
-            job_name="evaluator",
-            protocol=protocol,
-            task_index=0,
-            config=worker_config,
-            start=True)
-    ]
-  else:
-    evals = []
+  def __getitem__(self, key):
+    if not hasattr(self._thread_local, "dict"):
+      self._thread_local.dict = dict()
+    if key == "TF_CONFIG":
+      return dict.__getitem__(self._thread_local.dict, key)
+    else:
+      return dict.__getitem__(self, key)
 
-  return workers, ps_servers, evals
-
-
-def _create_in_process_cluster(num_workers, num_ps, has_eval=False):
-  """Create an in-process cluster that consists of only standard server."""
-  # Leave some memory for cuda runtime.
-  if has_eval:
-    gpu_mem_frac = 0.7 / (num_workers + 1)
-  else:
-    gpu_mem_frac = 0.7 / num_workers
-
-  worker_config = config_pb2.ConfigProto()
-  worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac
-
-  # Enable collective ops which has no impact on non-collective ops.
-  # TODO(yuefengz, tucker): removing this after we move the initialization of
-  # collective mgr to the session level.
-  worker_config.experimental.collective_group_leader = (
-      "/job:worker/replica:0/task:0")
-
-  ps_config = config_pb2.ConfigProto()
-  ps_config.device_count["GPU"] = 0
-
-  return _create_local_cluster(
-      num_workers,
-      num_ps=num_ps,
-      has_eval=has_eval,
-      worker_config=worker_config,
-      ps_config=ps_config,
-      protocol="grpc")
-
-
-def _create_cluster_spec(has_chief=False,
-                         num_workers=1,
-                         num_ps=0,
-                         has_eval=False):
-  if _portpicker_import_error:
-    raise _portpicker_import_error  # pylint: disable=raising-bad-type
-
-  cluster_spec = {}
-  if has_chief:
-    cluster_spec[CHIEF] = ["localhost:%s" % portpicker.pick_unused_port()]
-  if num_workers:
-    cluster_spec[WORKER] = [
-        "localhost:%s" % portpicker.pick_unused_port()
-        for _ in range(num_workers)
-    ]
-  if num_ps:
-    cluster_spec[PS] = [
-        "localhost:%s" % portpicker.pick_unused_port() for _ in range(num_ps)
-    ]
-  if has_eval:
-    cluster_spec[EVALUATOR] = ["localhost:%s" % portpicker.pick_unused_port()]
-  return cluster_spec
-
-
-def _bytes_to_str(maybe_bytes):
-  if isinstance(maybe_bytes, six.string_types):
-    return maybe_bytes
-  else:
-    return str(maybe_bytes, "utf-8")
-
-
-def _strip_protocol(target):
-  # cluster_spec expects "host:port" strings.
-  if "//" in target:
-    return target.split("//")[1]
-  else:
-    return target
+  def __setitem__(self, key, val):
+    if not hasattr(self._thread_local, "dict"):
+      self._thread_local.dict = dict()
+    if key == "TF_CONFIG":
+      return dict.__setitem__(self._thread_local.dict, key, val)
+    else:
+      return dict.__setitem__(self, key, val)
 
 
 class DistributeCoordinatorIntegrationTest(test.TestCase,
@@ -205,22 +103,20 @@
   @classmethod
   def setUpClass(cls):
     """Create a local cluster with 2 workers."""
-    cls._workers, cls._ps, cls._evals = _create_in_process_cluster(
+    cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
         num_workers=3, num_ps=2, has_eval=True)
-    cls._cluster_spec = {
-        "worker": [
-            _strip_protocol(_bytes_to_str(w.target)) for w in cls._workers
-        ],
-        "ps": [_strip_protocol(_bytes_to_str(ps.target)) for ps in cls._ps],
-        "evaluator": [
-            _strip_protocol(_bytes_to_str(e.target)) for e in cls._evals
-        ]
-    }
 
   def setUp(self):
     self._model_dir = tempfile.mkdtemp()
-    self._event = threading.Event()
+    self._mock_os_env = MockOsEnv()
+    self._mock_context = test.mock.patch.object(os, "environ",
+                                                self._mock_os_env)
     super(DistributeCoordinatorIntegrationTest, self).setUp()
+    self._mock_context.__enter__()
+
+  def tearDown(self):
+    self._mock_context.__exit__(None, None, None)
+    super(DistributeCoordinatorIntegrationTest, self).tearDown()
 
   def dataset_input_fn(self, x, y, batch_size, shuffle):
 
@@ -391,43 +287,17 @@
         train_distribute, eval_distribute, remote_cluster=self._cluster_spec)
     self._inspect_train_and_eval_events(estimator)
 
-  def _mock_run_distribute_coordinator(
-      self,
-      worker_fn,
-      strategy,
-      eval_fn,
-      eval_strategy,
-      mode=dc.CoordinatorMode.STANDALONE_CLIENT,
-      cluster_spec=None,
-      session_config=None):
-    # Calls the origial `run_distribute_coordinator` method but gets task config
-    # from environment variables and then signals the caller.
-    task_type = None
-    task_id = None
-    if not cluster_spec:
-      cluster_spec = None
-      tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
-      if not cluster_spec:
-        cluster_spec = tf_config.get("cluster", {})
-        task_env = tf_config.get("task", {})
-        if task_env:
-          task_type = task_env.get("type", task_type)
-          task_id = int(task_env.get("index", task_id))
-    self._event.set()
-    original_run_distribute_coordinator(
-        worker_fn,
-        strategy,
-        eval_fn,
-        eval_strategy,
-        mode=mode,
-        cluster_spec=cluster_spec,
-        task_type=task_type,
-        task_id=task_id,
-        session_config=session_config)
+  def _mock_run_std_server(self, *args, **kwargs):
+    ret = original_run_std_server(*args, **kwargs)
+    # Wait for all std servers to be brought up in order to reduce the chance of
+    # remote sessions taking local ports that have been assigned to std servers.
+    self._barrier.wait()
+    return ret
 
-  def _task_thread(self, train_distribute, eval_distribute):
-    with test.mock.patch.object(dc, "run_distribute_coordinator",
-                                self._mock_run_distribute_coordinator):
+  def _task_thread(self, train_distribute, eval_distribute, tf_config):
+    os.environ["TF_CONFIG"] = json.dumps(tf_config)
+    with test.mock.patch.object(dc, "_run_std_server",
+                                self._mock_run_std_server):
       self._complete_flow(train_distribute, eval_distribute)
 
   def _run_task_in_thread(self, cluster_spec, task_type, task_id,
@@ -448,13 +318,10 @@
               "index": task_id
           }
       }
-    self._event.clear()
     t = threading.Thread(
-        target=self._task_thread, args=(train_distribute, eval_distribute))
-    with test.mock.patch.dict("os.environ",
-                              {"TF_CONFIG": json.dumps(tf_config)}):
-      t.start()
-      self._event.wait()
+        target=self._task_thread,
+        args=(train_distribute, eval_distribute, tf_config))
+    t.start()
     return t
 
   def _run_multiple_tasks_in_threads(self, cluster_spec, train_distribute,
@@ -489,7 +356,11 @@
     else:
       eval_distribute = None
 
-    cluster_spec = _create_cluster_spec(num_workers=3, num_ps=2, has_eval=True)
+    cluster_spec = multi_worker_test_base.create_cluster_spec(
+        num_workers=3, num_ps=2, has_eval=True)
+    # 3 workers, 2 ps and 1 evaluator.
+    self._barrier = dc._Barrier(6)
+
     threads = self._run_multiple_tasks_in_threads(
         cluster_spec, train_distribute, eval_distribute)
     for task_type, ts in threads.items():
@@ -516,7 +387,10 @@
     else:
       eval_distribute = None
 
-    cluster_spec = _create_cluster_spec(num_workers=3, num_ps=2, has_eval=True)
+    cluster_spec = multi_worker_test_base.create_cluster_spec(
+        num_workers=3, num_ps=0, has_eval=True)
+    # 3 workers and 1 evaluator.
+    self._barrier = dc._Barrier(4)
     threads = self._run_multiple_tasks_in_threads(
         cluster_spec, train_distribute, eval_distribute)
     threads[WORKER][0].join()
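A standalone sketch of the synchronization pattern the test now relies on, using the Python 3 `threading.Barrier` as a stand-in for `dc._Barrier` (names here are illustrative):

```python
import threading

barrier = threading.Barrier(6)  # 3 workers + 2 ps + 1 evaluator

def run_std_server_then_wait(original_run_std_server, *args, **kwargs):
  server = original_run_std_server(*args, **kwargs)
  # Every task waits here, so all std servers are up before any remote
  # session starts grabbing local ports.
  barrier.wait()
  return server
```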
diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py
index 5f35e38..8165a70 100644
--- a/tensorflow/contrib/distribute/python/keras_test.py
+++ b/tensorflow/contrib/distribute/python/keras_test.py
@@ -732,14 +732,22 @@
     with self.cached_session():
       keras.backend.set_image_data_format('channels_last')
       num_samples = 10000
+
+      # Train and predict datasets are created with the same input numpy arrays.
       x_train = np.random.rand(num_samples, 1)
       y_train = 3 * x_train
       x_train = x_train.astype('float32')
       y_train = y_train.astype('float32')
 
+      # The model is built once and the initial weights are saved.
+      # These are used to initialize the model for both the distribution and
+      # non-distribution runs.
+      model = keras.Sequential()
+      model.add(keras.layers.Dense(1, input_shape=(1,)))
+      initial_weights = model.get_weights()
+
       def fit_and_predict(with_distribution=None):
-        model = keras.Sequential()
-        model.add(keras.layers.Dense(1, input_shape=(1,)))
+        model.set_weights(initial_weights)
         model.compile(
             loss=keras.losses.mean_squared_error,
             optimizer=gradient_descent.GradientDescentOptimizer(0.5),
@@ -751,12 +759,14 @@
         train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train,
                                                                 y_train))
         train_dataset = batch_wrapper(train_dataset, batch_size, distribution)
-        # Running only 100 steps instead of the full dataset to keep test
-        # duration small.
-        model.fit(x=train_dataset, epochs=1, steps_per_epoch=100)
+        # The model has been initialized with the same weights for the
+        # distribution and non-distribution runs. If instead it were
+        # initialized with random weights for each run, it would need to be
+        # trained through the entire dataset at least once so that the weights
+        # of both runs converge to the same values.
+        model.fit(x=train_dataset, epochs=1, steps_per_epoch=10)
 
         weights = model.get_weights()
-
         x_predict = [[1.], [2.], [3.], [4.]]
         predict_batch_size = 4
         if with_distribution:
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index c6894e9..f51e543 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -1271,7 +1271,17 @@
                             self.evaluate(device_result))
 
       for defun in defuns:
-        self.assertEqual(set(mock_model.variables), set(defun.variables))
+        # PolymorphicFunctions are specialized to the current device stack, so
+        # call_for_each has one trace per device. To check that the expected set
+        # of variables was accessed on each trace, we first retrieve each
+        # device-specific graph function.
+        per_device_graph_functions = dist.call_for_each_tower(
+            defun.get_concrete_function,
+            mock_model, *inputs, run_concurrently=False)
+        for device in devices:
+          graph_function = per_device_graph_functions.get(device=device)
+          self.assertEqual(set(mock_model.variables),
+                           set(graph_function.graph.variables))
 
   @test_util.run_in_graph_and_eager_modes()
   def testVariableInDefun(self):
diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py
index 18b4503..9f92ba7 100644
--- a/tensorflow/contrib/distribute/python/multi_worker_test_base.py
+++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py
@@ -36,9 +36,29 @@
 from tensorflow.python.client import session
 from tensorflow.python.estimator import run_config
 from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import server_lib
 
 
+ASSIGNED_PORTS = set()
+lock = threading.Lock()
+
+
+def pick_unused_port():
+  """Returns an unused and unassigned local port."""
+  if _portpicker_import_error:
+    raise _portpicker_import_error  # pylint: disable=raising-bad-type
+
+  global ASSIGNED_PORTS
+  with lock:
+    while True:
+      port = portpicker.pick_unused_port()
+      if port > 10000 and port not in ASSIGNED_PORTS:
+        ASSIGNED_PORTS.add(port)
+        logging.info('Using local port %r', port)
+        return port
+
+
 def _create_cluster(num_workers,
                     num_ps,
                     has_chief=False,
@@ -49,8 +69,8 @@
   """Creates and starts local servers and returns the cluster_spec dict."""
   if _portpicker_import_error:
     raise _portpicker_import_error  # pylint: disable=raising-bad-type
-  worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
-  ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
+  worker_ports = [pick_unused_port() for _ in range(num_workers)]
+  ps_ports = [pick_unused_port() for _ in range(num_ps)]
 
   cluster_dict = {}
   if num_workers > 0:
@@ -58,9 +78,9 @@
   if num_ps > 0:
     cluster_dict['ps'] = ['localhost:%s' % port for port in ps_ports]
   if has_eval:
-    cluster_dict['evaluator'] = ['localhost:%s' % portpicker.pick_unused_port()]
+    cluster_dict['evaluator'] = ['localhost:%s' % pick_unused_port()]
   if has_chief:
-    cluster_dict['chief'] = ['localhost:%s' % portpicker.pick_unused_port()]
+    cluster_dict['chief'] = ['localhost:%s' % pick_unused_port()]
 
   cs = server_lib.ClusterSpec(cluster_dict)
 
@@ -139,11 +159,36 @@
       num_workers,
       num_ps=num_ps,
       has_chief=has_chief,
+      has_eval=has_eval,
       worker_config=worker_config,
       ps_config=ps_config,
       protocol='grpc')
 
 
+def create_cluster_spec(has_chief=False,
+                        num_workers=1,
+                        num_ps=0,
+                        has_eval=False):
+  """Create a cluster spec with tasks with unused local ports."""
+  if _portpicker_import_error:
+    raise _portpicker_import_error  # pylint: disable=raising-bad-type
+
+  cluster_spec = {}
+  if has_chief:
+    cluster_spec['chief'] = ['localhost:%s' % pick_unused_port()]
+  if num_workers:
+    cluster_spec['worker'] = [
+        'localhost:%s' % pick_unused_port() for _ in range(num_workers)
+    ]
+  if num_ps:
+    cluster_spec['ps'] = [
+        'localhost:%s' % pick_unused_port() for _ in range(num_ps)
+    ]
+  if has_eval:
+    cluster_spec['evaluator'] = ['localhost:%s' % pick_unused_port()]
+  return cluster_spec
+
+
 class MultiWorkerTestBase(test.TestCase):
   """Base class for testing multi node strategy and dataset."""
 
diff --git a/tensorflow/contrib/distribute/python/single_loss_example.py b/tensorflow/contrib/distribute/python/single_loss_example.py
index 5aa19cf..09b351f 100644
--- a/tensorflow/contrib/distribute/python/single_loss_example.py
+++ b/tensorflow/contrib/distribute/python/single_loss_example.py
@@ -18,7 +18,6 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.data.python.ops import batching
 from tensorflow.contrib.distribute.python import step_fn
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import constant_op
@@ -59,10 +58,9 @@
 
   def dataset_fn():
     dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat()
-    # TODO(isaprykin): map_and_batch with drop_remainder causes shapes to be
+    # TODO(isaprykin): batch with drop_remainder causes shapes to be
     # fully defined for TPU.  Remove this when XLA supports dynamic shapes.
-    return dataset.apply(
-        batching.map_and_batch(lambda x: x, batch_size=1, drop_remainder=True))
+    return dataset.batch(1, drop_remainder=True)
 
   # An Optimizer instance is created either outside or inside model_fn.
   outer_optimizer = None
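As a side note to this change, a small sketch of why `drop_remainder=True` matters here: it makes the batch dimension statically known, which is what the TODO above relies on for TPU/XLA.

```python
import tensorflow as tf

ds = tf.data.Dataset.from_tensors([[1.]]).repeat()
print(ds.batch(1, drop_remainder=True).output_shapes)  # (1, 1, 1)
print(ds.batch(1).output_shapes)                       # (?, 1, 1)
```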
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index 9aadc63..3ff7da4 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -25,7 +25,6 @@
                    "`tf.contrib.distributions` to `tfp.distributions`."),
     srcs_version = "PY2AND3",
     deps = [
-        "//tensorflow/contrib/linalg:linalg_py",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:check_ops",
         "//tensorflow/python:clip_ops",
@@ -61,7 +60,6 @@
         ":bijectors_py",
         "//tensorflow/contrib/framework:framework_py",
         "//tensorflow/contrib/learn",
-        "//tensorflow/contrib/linalg:linalg_py",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:check_ops",
         "//tensorflow/python:control_flow_ops",
@@ -706,8 +704,8 @@
         ":bijectors_py",
         ":distributions_py",
         "//third_party/py/numpy",
-        "//tensorflow/contrib/linalg:linalg_py",
         "//tensorflow/python:array_ops",
+        "//tensorflow/python/ops/linalg",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:math_ops",
@@ -722,8 +720,8 @@
     additional_deps = [
         ":distributions_py",
         "//third_party/py/numpy",
-        "//tensorflow/contrib/linalg:linalg_py",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python/ops/linalg",
     ],
     shard_count = 4,
     tags = ["noasan"],  # times out, http://b/78588814
@@ -739,8 +737,8 @@
     additional_deps = [
         ":distributions_py",
         "//third_party/py/numpy",
-        "//tensorflow/contrib/linalg:linalg_py",
         "//tensorflow/python:array_ops",
+        "//tensorflow/python/ops/linalg",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:math_ops",
@@ -794,8 +792,8 @@
         ":distributions_py",
         "//third_party/py/numpy",
         "@six_archive//:six",
-        "//tensorflow/contrib/linalg:linalg_py",
         "//tensorflow/python:array_ops",
+        "//tensorflow/python/ops/linalg",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:framework_test_lib",
@@ -831,8 +829,8 @@
         ":distributions_py",
         "//third_party/py/numpy",
         "@six_archive//:six",
-        "//tensorflow/contrib/linalg:linalg_py",
         "//tensorflow/python:array_ops",
+        "//tensorflow/python/ops/linalg",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:framework_test_lib",
@@ -852,8 +850,8 @@
         ":distributions_py",
         "//third_party/py/numpy",
         "@six_archive//:six",
-        "//tensorflow/contrib/linalg:linalg_py",
         "//tensorflow/python:array_ops",
+        "//tensorflow/python/ops/linalg",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:framework_test_lib",
@@ -871,8 +869,8 @@
         ":distributions_py",
         "//third_party/py/numpy",
         "@six_archive//:six",
-        "//tensorflow/contrib/linalg:linalg_py",
         "//tensorflow/python:array_ops",
+        "//tensorflow/python/ops/linalg",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:framework_test_lib",
@@ -907,8 +905,8 @@
         ":distributions_py",
         "//third_party/py/numpy",
         "@six_archive//:six",
-        "//tensorflow/contrib/linalg:linalg_py",
         "//tensorflow/python:array_ops",
+        "//tensorflow/python/ops/linalg",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:framework_test_lib",
@@ -926,10 +924,10 @@
         ":distributions_py",
         "//third_party/py/numpy",
         "@six_archive//:six",
-        "//tensorflow/contrib/linalg:linalg_py",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python/ops/linalg",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:platform_test",
@@ -945,8 +943,8 @@
         ":distributions_py",
         "//third_party/py/numpy",
         "@six_archive//:six",
-        "//tensorflow/contrib/linalg:linalg_py",
         "//tensorflow/python:array_ops",
+        "//tensorflow/python/ops/linalg",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:framework_test_lib",
@@ -964,8 +962,8 @@
         ":distributions_py",
         "//third_party/py/numpy",
         "@six_archive//:six",
-        "//tensorflow/contrib/linalg:linalg_py",
         "//tensorflow/python:array_ops",
+        "//tensorflow/python/ops/linalg",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:framework_test_lib",
@@ -983,8 +981,8 @@
         ":distributions_py",
         "//third_party/py/numpy",
         "@six_archive//:six",
-        "//tensorflow/contrib/linalg:linalg_py",
         "//tensorflow/python:array_ops",
+        "//tensorflow/python/ops/linalg",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:framework_test_lib",
@@ -1002,8 +1000,8 @@
         ":distributions_py",
         "//third_party/py/numpy",
         "@six_archive//:six",
-        "//tensorflow/contrib/linalg:linalg_py",
         "//tensorflow/python:array_ops",
+        "//tensorflow/python/ops/linalg",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:framework_test_lib",
@@ -1021,8 +1019,8 @@
         ":distributions_py",
         "//third_party/py/numpy",
         "@six_archive//:six",
-        "//tensorflow/contrib/linalg:linalg_py",
         "//tensorflow/python:array_ops",
+        "//tensorflow/python/ops/linalg",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:framework_test_lib",
@@ -1040,8 +1038,8 @@
         ":distributions_py",
         "//third_party/py/numpy",
         "@six_archive//:six",
-        "//tensorflow/contrib/linalg:linalg_py",
         "//tensorflow/python:array_ops",
+        "//tensorflow/python/ops/linalg",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:framework_test_lib",
@@ -1075,8 +1073,8 @@
         ":distributions_py",
         "//third_party/py/numpy",
         "@six_archive//:six",
-        "//tensorflow/contrib/linalg:linalg_py",
         "//tensorflow/python:array_ops",
+        "//tensorflow/python/ops/linalg",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:framework_test_lib",
@@ -1126,8 +1124,8 @@
         ":distributions_py",
         "//third_party/py/numpy",
         "@six_archive//:six",
-        "//tensorflow/contrib/linalg:linalg_py",
         "//tensorflow/python:array_ops",
+        "//tensorflow/python/ops/linalg",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:framework_test_lib",
@@ -1161,8 +1159,8 @@
         ":distributions_py",
         "//third_party/py/numpy",
         "@six_archive//:six",
-        "//tensorflow/contrib/linalg:linalg_py",
         "//tensorflow/python:array_ops",
+        "//tensorflow/python/ops/linalg",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:framework_test_lib",
@@ -1180,8 +1178,8 @@
         ":distributions_py",
         "//third_party/py/numpy",
         "@six_archive//:six",
-        "//tensorflow/contrib/linalg:linalg_py",
         "//tensorflow/python:array_ops",
+        "//tensorflow/python/ops/linalg",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:framework_test_lib",
@@ -1201,8 +1199,8 @@
         ":distributions_py",
         "//third_party/py/numpy",
         "@six_archive//:six",
-        "//tensorflow/contrib/linalg:linalg_py",
         "//tensorflow/python:array_ops",
+        "//tensorflow/python/ops/linalg",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:framework_test_lib",
@@ -1221,8 +1219,8 @@
         ":distributions_py",
         "//third_party/py/numpy",
         "@six_archive//:six",
-        "//tensorflow/contrib/linalg:linalg_py",
         "//tensorflow/python:array_ops",
+        "//tensorflow/python/ops/linalg",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:framework_test_lib",
@@ -1240,8 +1238,8 @@
         ":distributions_py",
         "//third_party/py/numpy",
         "@six_archive//:six",
-        "//tensorflow/contrib/linalg:linalg_py",
         "//tensorflow/python:array_ops",
+        "//tensorflow/python/ops/linalg",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:framework_test_lib",
@@ -1259,8 +1257,8 @@
         ":distributions_py",
         "//third_party/py/numpy",
         "@six_archive//:six",
-        "//tensorflow/contrib/linalg:linalg_py",
         "//tensorflow/python:array_ops",
+        "//tensorflow/python/ops/linalg",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:framework_test_lib",
@@ -1278,8 +1276,8 @@
         ":distributions_py",
         "//third_party/py/numpy",
         "@six_archive//:six",
-        "//tensorflow/contrib/linalg:linalg_py",
         "//tensorflow/python:array_ops",
+        "//tensorflow/python/ops/linalg",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:framework_test_lib",
@@ -1297,8 +1295,8 @@
         ":distributions_py",
         "//third_party/py/numpy",
         "@six_archive//:six",
-        "//tensorflow/contrib/linalg:linalg_py",
         "//tensorflow/python:array_ops",
+        "//tensorflow/python/ops/linalg",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:framework_test_lib",
@@ -1316,8 +1314,8 @@
         ":distributions_py",
         "//third_party/py/numpy",
         "@six_archive//:six",
-        "//tensorflow/contrib/linalg:linalg_py",
         "//tensorflow/python:array_ops",
+        "//tensorflow/python/ops/linalg",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:framework_test_lib",
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py
index 8dad80a..c32ea9a 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py
@@ -93,12 +93,12 @@
             bijector.inverse_log_det_jacobian(y, event_ndims=1)))
 
   def testScalarCongruency(self):
-    with self.test_session():
+    with self.cached_session():
       bijector = Softsign(validate_args=True)
       assert_scalar_congruency(bijector, lower_x=-20., upper_x=20.)
 
   def testBijectiveAndFinite(self):
-    with self.test_session():
+    with self.cached_session():
       bijector = Softsign(validate_args=True)
       x = np.linspace(-20., 20., 100).astype(np.float32)
       y = np.linspace(-0.99, 0.99, 100).astype(np.float32)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py
index f073f51..9b9b3ce 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py
@@ -212,7 +212,7 @@
   def testStrWorksCorrectlyScalar(self):
     normal = tfd.Normal(loc=np.float16(0), scale=np.float16(1))
     self.assertEqual(
-        ("tf.distributions.Normal("
+        ("tfp.distributions.Normal("
          "\"Normal/\", "
          "batch_shape=(), "
          "event_shape=(), "
@@ -221,7 +221,7 @@
 
     chi2 = tfd.Chi2(df=np.float32([1., 2.]), name="silly")
     self.assertEqual(
-        ("tf.distributions.Chi2("
+        ("tfp.distributions.Chi2("
          "\"silly/\", "  # What a silly name that is!
          "batch_shape=(2,), "
          "event_shape=(), "
@@ -230,7 +230,7 @@
 
     exp = tfd.Exponential(rate=array_ops.placeholder(dtype=dtypes.float32))
     self.assertEqual(
-        ("tf.distributions.Exponential(\"Exponential/\", "
+        ("tfp.distributions.Exponential(\"Exponential/\", "
          # No batch shape.
          "event_shape=(), "
          "dtype=float32)"),
@@ -240,7 +240,7 @@
     mvn_static = tfd.MultivariateNormalDiag(
         loc=np.zeros([2, 2]), name="MVN")
     self.assertEqual(
-        ("tf.distributions.MultivariateNormalDiag("
+        ("tfp.distributions.MultivariateNormalDiag("
          "\"MVN/\", "
          "batch_shape=(2,), "
          "event_shape=(2,), "
@@ -251,7 +251,7 @@
         loc=array_ops.placeholder(shape=[None, 3], dtype=dtypes.float32),
         name="MVN2")
     self.assertEqual(
-        ("tf.distributions.MultivariateNormalDiag("
+        ("tfp.distributions.MultivariateNormalDiag("
          "\"MVN2/\", "
          "batch_shape=(?,), "  # Partially known.
          "event_shape=(3,), "
@@ -261,7 +261,7 @@
   def testReprWorksCorrectlyScalar(self):
     normal = tfd.Normal(loc=np.float16(0), scale=np.float16(1))
     self.assertEqual(
-        ("<tf.distributions.Normal"
+        ("<tfp.distributions.Normal"
          " 'Normal/'"
          " batch_shape=()"
          " event_shape=()"
@@ -270,7 +270,7 @@
 
     chi2 = tfd.Chi2(df=np.float32([1., 2.]), name="silly")
     self.assertEqual(
-        ("<tf.distributions.Chi2"
+        ("<tfp.distributions.Chi2"
          " 'silly/'"  # What a silly name that is!
          " batch_shape=(2,)"
          " event_shape=()"
@@ -279,7 +279,7 @@
 
     exp = tfd.Exponential(rate=array_ops.placeholder(dtype=dtypes.float32))
     self.assertEqual(
-        ("<tf.distributions.Exponential"
+        ("<tfp.distributions.Exponential"
          " 'Exponential/'"
          " batch_shape=<unknown>"
          " event_shape=()"
@@ -290,7 +290,7 @@
     mvn_static = tfd.MultivariateNormalDiag(
         loc=np.zeros([2, 2]), name="MVN")
     self.assertEqual(
-        ("<tf.distributions.MultivariateNormalDiag"
+        ("<tfp.distributions.MultivariateNormalDiag"
          " 'MVN/'"
          " batch_shape=(2,)"
          " event_shape=(2,)"
@@ -301,7 +301,7 @@
         loc=array_ops.placeholder(shape=[None, 3], dtype=dtypes.float32),
         name="MVN2")
     self.assertEqual(
-        ("<tf.distributions.MultivariateNormalDiag"
+        ("<tfp.distributions.MultivariateNormalDiag"
          " 'MVN2/'"
          " batch_shape=(?,)"  # Partially known.
          " event_shape=(3,)"
diff --git a/tensorflow/contrib/distributions/python/ops/autoregressive.py b/tensorflow/contrib/distributions/python/ops/autoregressive.py
index bb9b804..3ba1c3a 100644
--- a/tensorflow/contrib/distributions/python/ops/autoregressive.py
+++ b/tensorflow/contrib/distributions/python/ops/autoregressive.py
@@ -65,13 +65,14 @@
   ```
 
   where the ellipses (`...`) represent `n-2` composed calls to `fn`, `fn`
-  constructs a `tf.distributions.Distribution`-like instance, and `x0` is a
+  constructs a `tfp.distributions.Distribution`-like instance, and `x0` is a
   fixed initializing `Tensor`.
 
   #### Examples
 
   ```python
-  tfd = tf.contrib.distributions
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
 
   def normal_fn(self, event_size):
     n = event_size * (event_size + 1) / 2
@@ -127,7 +128,7 @@
 
     Args:
       distribution_fn: Python `callable` which constructs a
-        `tf.distributions.Distribution`-like instance from a `Tensor` (e.g.,
+        `tfp.distributions.Distribution`-like instance from a `Tensor` (e.g.,
         `sample0`). The function must respect the "autoregressive property",
         i.e., there exists a permutation of event such that each coordinate is a
         diffeomorphic function of on preceding coordinates.
diff --git a/tensorflow/contrib/distributions/python/ops/batch_reshape.py b/tensorflow/contrib/distributions/python/ops/batch_reshape.py
index 519077b..612376e 100644
--- a/tensorflow/contrib/distributions/python/ops/batch_reshape.py
+++ b/tensorflow/contrib/distributions/python/ops/batch_reshape.py
@@ -45,7 +45,8 @@
   #### Examples
 
   ```python
-  tfd = tf.contrib.distributions
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
 
   dtype = np.float32
   dims = 2
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py
index 296e66f..3b3d8ee 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py
@@ -61,8 +61,8 @@
   `shift_and_log_scale_fn`, `masked_autoregressive_default_template`, achieves
   this property by zeroing out weights in its `masked_dense` layers.
 
-  In the `tf.distributions` framework, a "normalizing flow" is implemented as a
-  `tf.contrib.distributions.bijectors.Bijector`. The `forward` "autoregression"
+  In the `tfp` framework, a "normalizing flow" is implemented as a
+  `tfp.bijectors.Bijector`. The `forward` "autoregression"
   is implemented using a `tf.while_loop` and a deep neural network (DNN) with
   masked weights such that the autoregressive property is automatically met in
   the `inverse`.
@@ -126,8 +126,9 @@
   #### Examples
 
   ```python
-  tfd = tf.contrib.distributions
-  tfb = tfd.bijectors
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
+  tfb = tfp.bijectors
 
   dims = 5
 
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py
index f182a1a..178c3c9 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py
@@ -41,9 +41,10 @@
   """Permutes the rightmost dimension of a `Tensor`.
 
   ```python
-  tfd = tf.contrib.distributions
+  import tensorflow_probability as tfp
+  tfb = tfp.bijectors
 
-  reverse = tfd.bijectors.Permute(permutation=[2, 1, 0])
+  reverse = tfb.Permute(permutation=[2, 1, 0])
 
   reverse.forward([-1., 0., 1.])
   # ==> [1., 0., -1]
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py
index 773ae24..0bcb08c 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py
@@ -90,8 +90,9 @@
   #### Example Use
 
   ```python
-  tfd = tf.contrib.distributions
-  tfb = tfd.bijectors
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
+  tfb = tfp.bijectors
 
   # A common choice for a normalizing flow is to use a Gaussian for the base
   # distribution. (However, any continuous distribution would work.) E.g.,
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py
index c828222..71ac290 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py
@@ -80,9 +80,10 @@
   Example usage:
   ```python
 
-  tfd = tf.contrib.distributions
+  import tensorflow_probability as tfp
+  tfb = tfp.bijectors
 
-  r = tfd.bijectors.Reshape(event_shape_out=[1, -1])
+  r = tfb.Reshape(event_shape_out=[1, -1])
 
   r.forward([3., 4.])    # shape [2]
   # ==> [[3., 4.]]       # shape [1, 2]
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py b/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py
index 6fbe866..0a6d690 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py
@@ -42,7 +42,10 @@
   #### Examples
 
   ```python
-  tfb = tf.contrib.distributions.bijectors
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
+  tfb = tfp.bijectors
+
   b = tfb.ScaleTriL(
        diag_bijector=tfb.Exp(),
        diag_shift=None)
diff --git a/tensorflow/contrib/distributions/python/ops/cauchy.py b/tensorflow/contrib/distributions/python/ops/cauchy.py
index cb5223b..c461833 100644
--- a/tensorflow/contrib/distributions/python/ops/cauchy.py
+++ b/tensorflow/contrib/distributions/python/ops/cauchy.py
@@ -63,7 +63,8 @@
   Examples of initialization of one or a batch of distributions.
 
   ```python
-  tfd = tf.contrib.distributions
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
 
   # Define a single scalar Cauchy distribution.
   dist = tfd.Cauchy(loc=0., scale=3.)
diff --git a/tensorflow/contrib/distributions/python/ops/deterministic.py b/tensorflow/contrib/distributions/python/ops/deterministic.py
index affc64a..507c5d3 100644
--- a/tensorflow/contrib/distributions/python/ops/deterministic.py
+++ b/tensorflow/contrib/distributions/python/ops/deterministic.py
@@ -198,8 +198,11 @@
   #### Examples
 
   ```python
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
+
   # Initialize a single Deterministic supported at zero.
-  constant = tf.contrib.distributions.Deterministic(0.)
+  constant = tfd.Deterministic(0.)
   constant.prob(0.)
   ==> 1.
   constant.prob(2.)
@@ -208,7 +211,7 @@
   # Initialize a [2, 2] batch of scalar constants.
   loc = [[0., 1.], [2., 3.]]
   x = [[0., 1.1], [1.99, 3.]]
-  constant = tf.contrib.distributions.Deterministic(loc)
+  constant = tfd.Deterministic(loc)
   constant.prob(x)
   ==> [[1., 0.], [0., 1.]]
   ```
@@ -310,7 +313,8 @@
   #### Examples
 
   ```python
-  tfd = tf.contrib.distributions
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
 
   # Initialize a single VectorDeterministic supported at [0., 2.] in R^2.
   constant = tfd.Deterministic([0., 2.])
diff --git a/tensorflow/contrib/distributions/python/ops/gumbel.py b/tensorflow/contrib/distributions/python/ops/gumbel.py
index acdea4d..4b50df5 100644
--- a/tensorflow/contrib/distributions/python/ops/gumbel.py
+++ b/tensorflow/contrib/distributions/python/ops/gumbel.py
@@ -63,7 +63,8 @@
   Examples of initialization of one or a batch of distributions.
 
   ```python
-  tfd = tf.contrib.distributions
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
 
   # Define a single scalar Gumbel distribution.
   dist = tfd.Gumbel(loc=0., scale=3.)
diff --git a/tensorflow/contrib/distributions/python/ops/half_normal.py b/tensorflow/contrib/distributions/python/ops/half_normal.py
index b02c403..f121637 100644
--- a/tensorflow/contrib/distributions/python/ops/half_normal.py
+++ b/tensorflow/contrib/distributions/python/ops/half_normal.py
@@ -66,15 +66,18 @@
   Examples of initialization of one or a batch of distributions.
 
   ```python
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
+
   # Define a single scalar HalfNormal distribution.
-  dist = tf.contrib.distributions.HalfNormal(scale=3.0)
+  dist = tfd.HalfNormal(scale=3.0)
 
   # Evaluate the cdf at 1, returning a scalar.
   dist.cdf(1.)
 
   # Define a batch of two scalar valued HalfNormals.
   # The first has scale 11.0, the second 22.0
-  dist = tf.contrib.distributions.HalfNormal(scale=[11.0, 22.0])
+  dist = tfd.HalfNormal(scale=[11.0, 22.0])
 
   # Evaluate the pdf of the first distribution on 1.0, and the second on 1.5,
   # returning a length two tensor.
diff --git a/tensorflow/contrib/distributions/python/ops/independent.py b/tensorflow/contrib/distributions/python/ops/independent.py
index 0672702..e1cfff3 100644
--- a/tensorflow/contrib/distributions/python/ops/independent.py
+++ b/tensorflow/contrib/distributions/python/ops/independent.py
@@ -70,7 +70,8 @@
   #### Examples
 
   ```python
-  tfd = tf.contrib.distributions
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
 
   # Make independent distribution from a 2-batch Normal.
   ind = tfd.Independent(
diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
index 70d050d..4526282 100644
--- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
+++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
@@ -89,7 +89,9 @@
   #### Examples
 
   ```python
-  tfd = tf.contrib.distributions
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
+
   dist = tfd.InverseGamma(concentration=3.0, rate=2.0)
   dist2 = tfd.InverseGamma(concentration=[3.0, 4.0], rate=[2.0, 3.0])
   ```
diff --git a/tensorflow/contrib/distributions/python/ops/logistic.py b/tensorflow/contrib/distributions/python/ops/logistic.py
index 02e3bad..21c9b5a 100644
--- a/tensorflow/contrib/distributions/python/ops/logistic.py
+++ b/tensorflow/contrib/distributions/python/ops/logistic.py
@@ -61,7 +61,8 @@
   Examples of initialization of one or a batch of distributions.
 
   ```python
-  tfd = tf.contrib.distributions
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
 
   # Define a single scalar Logistic distribution.
   dist = tfd.Logistic(loc=0., scale=3.)
diff --git a/tensorflow/contrib/distributions/python/ops/mixture.py b/tensorflow/contrib/distributions/python/ops/mixture.py
index 3b7114e..52b67f2 100644
--- a/tensorflow/contrib/distributions/python/ops/mixture.py
+++ b/tensorflow/contrib/distributions/python/ops/mixture.py
@@ -50,7 +50,9 @@
 
   ```python
   # Create a mixture of two Gaussians:
-  tfd = tf.contrib.distributions
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
+
   mix = 0.3
   bimix_gauss = tfd.Mixture(
     cat=tfd.Categorical(probs=[mix, 1.-mix]),
diff --git a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py
index 8ffee94..f4d394f 100644
--- a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py
+++ b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py
@@ -44,7 +44,8 @@
   #### Examples
 
   ```python
-  tfd = tf.contrib.distributions
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
 
   ### Create a mixture of two scalar Gaussians:
 
@@ -113,12 +114,12 @@
     """Construct a `MixtureSameFamily` distribution.
 
     Args:
-      mixture_distribution: `tf.distributions.Categorical`-like instance.
+      mixture_distribution: `tfp.distributions.Categorical`-like instance.
         Manages the probability of selecting components. The number of
         categories must match the rightmost batch dimension of the
         `components_distribution`. Must have either scalar `batch_shape` or
         `batch_shape` matching `components_distribution.batch_shape[:-1]`.
-      components_distribution: `tf.distributions.Distribution`-like instance.
+      components_distribution: `tfp.distributions.Distribution`-like instance.
         Right-most batch dimension indexes components.
       validate_args: Python `bool`, default `False`. When `True` distribution
         parameters are checked for validity despite possibly degrading runtime
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag.py b/tensorflow/contrib/distributions/python/ops/mvn_diag.py
index cd0c282..0b5b76b 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_diag.py
@@ -85,7 +85,8 @@
   #### Examples
 
   ```python
-  tfd = tf.contrib.distributions
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
 
   # Initialize a single 2-variate Gaussian.
   mvn = tfd.MultivariateNormalDiag(
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
index 74d9d04..8054608 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
@@ -87,7 +87,8 @@
   #### Examples
 
   ```python
-  tfd = tf.contrib.distributions
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
 
   # Initialize a single 3-variate Gaussian with covariance `cov = S @ S.T`,
   # `S = diag(d) + U @ diag(m) @ U.T`. The perturbation, `U @ diag(m) @ U.T`, is
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py
index dbc4c1b..bcb4937 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py
@@ -73,7 +73,8 @@
   #### Examples
 
   ```python
-  tfd = tf.contrib.distributions
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
 
   # Initialize a single 3-variate Gaussian.
   mu = [1., 2, 3]
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py
index efe5a6d..8fdc998 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py
@@ -91,7 +91,8 @@
   #### Examples
 
   ```python
-  tfd = tf.contrib.distributions
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
 
   # Initialize a single 3-variate Gaussian.
   mu = [1., 2, 3]
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_tril.py b/tensorflow/contrib/distributions/python/ops/mvn_tril.py
index c6a23e4..c21f70f 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_tril.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_tril.py
@@ -77,13 +77,14 @@
   ```
 
   Trainable (batch) lower-triangular matrices can be created with
-  `tf.contrib.distributions.matrix_diag_transform()` and/or
-  `tf.contrib.distributions.fill_triangular()`
+  `tfp.distributions.matrix_diag_transform()` and/or
+  `tfp.distributions.fill_triangular()`
 
   #### Examples
 
   ```python
-  tfd = tf.contrib.distributions
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
 
   # Initialize a single 3-variate Gaussian.
   mu = [1., 2, 3]
diff --git a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py
index 7a7ad1b..85683e3 100644
--- a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py
+++ b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py
@@ -220,7 +220,8 @@
   #### Examples
 
   ```python
-  tfd = tf.contrib.distributions
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
 
   # Create two batches of PoissonLogNormalQuadratureCompounds, one with
   # prior `loc = 0.` and another with `loc = 1.` In both cases `scale = 1.`
diff --git a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
index 18a0f75..134658d 100644
--- a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
+++ b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
@@ -196,8 +196,9 @@
   parameter determining the unnormalized probability of that component.
 
   ```python
-  tfd = tf.contrib.distributions
-  tfb = tfd.bijectors
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
+  tfb = tfp.bijectors
 
   net = wavenet(inputs)
   loc, unconstrained_scale, logits = tf.split(net,
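
The WaveNet-style example above is truncated by the hunk boundary. A much smaller hedged sketch of the same class, quantizing a Normal onto integer values (the parameter values, and the use of 8-bit bounds, are illustrative assumptions):

```python
import tensorflow_probability as tfp

tfd = tfp.distributions

# A Normal rounded to integers and clipped to [0, 255], e.g. for 8-bit samples.
quantized = tfd.QuantizedDistribution(
    distribution=tfd.Normal(loc=100., scale=20.),
    low=0.,
    high=255.)

pmf_at_128 = quantized.prob(128.)  # Probability mass assigned to the value 128.
```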
diff --git a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py
index a9d0fb4..4b520b9 100644
--- a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py
+++ b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py
@@ -124,7 +124,7 @@
       tailweight:  Tailweight parameter. Default is `1.0` (unchanged tailweight)
       distribution: `tf.Distribution`-like instance. Distribution that is
         transformed to produce this distribution.
-        Default is `tf.distributions.Normal(0., 1.)`.
+        Default is `tfp.distributions.Normal(0., 1.)`.
         Must be a scalar-batch, scalar-event distribution.  Typically
         `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
         a function of non-trainable parameters. WARNING: If you backprop through
diff --git a/tensorflow/contrib/distributions/python/ops/statistical_testing.py b/tensorflow/contrib/distributions/python/ops/statistical_testing.py
index c25e8c5..af22f48 100644
--- a/tensorflow/contrib/distributions/python/ops/statistical_testing.py
+++ b/tensorflow/contrib/distributions/python/ops/statistical_testing.py
@@ -30,27 +30,27 @@
 `[0, 1]`.  Then you might do this:
 
 ```python
-tfd = tf.contrib.distributions
+  from tensorflow_probability.python.distributions.internal import statistical_testing
 
-expected_mean = ...
-num_samples = 5000
-samples = ... draw 5000 samples from P
+  expected_mean = ...
+  num_samples = 5000
+  samples = ... draw 5000 samples from P
 
-# Check that the mean looks right
-check1 = tfd.assert_true_mean_equal_by_dkwm(
-    samples, low=0., high=1., expected=expected_mean,
-    false_fail_rate=1e-6)
+  # Check that the mean looks right
+  check1 = statistical_testing.assert_true_mean_equal_by_dkwm(
+      samples, low=0., high=1., expected=expected_mean,
+      false_fail_rate=1e-6)
 
-# Check that the difference in means detectable with 5000 samples is
-# small enough
-check2 = tf.assert_less(
-    tfd.min_discrepancy_of_true_means_detectable_by_dkwm(
-        num_samples, low=0., high=1.0,
-        false_fail_rate=1e-6, false_pass_rate=1e-6),
-    0.01)
+  # Check that the difference in means detectable with 5000 samples is
+  # small enough
+  check2 = tf.assert_less(
+      statistical_testing.min_discrepancy_of_true_means_detectable_by_dkwm(
+          num_samples, low=0., high=1.0,
+          false_fail_rate=1e-6, false_pass_rate=1e-6),
+      0.01)
 
-# Be sure to execute both assertion ops
-sess.run([check1, check2])
+  # Be sure to execute both assertion ops
+  sess.run([check1, check2])
 ```
 
 The second assertion is an instance of experiment design.  It's a
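
A runnable sketch following the calls shown in the rewritten docstring above (the import path and function names come from the hunk itself; the uniform samples and `expected_mean = 0.5` are synthetic stand-ins for "draw 5000 samples from P"):

```python
import numpy as np
import tensorflow as tf
from tensorflow_probability.python.distributions.internal import statistical_testing

expected_mean = 0.5
num_samples = 5000
# Synthetic stand-in for samples drawn from the distribution under test.
samples = tf.constant(np.random.uniform(size=num_samples), dtype=tf.float32)

# Check that the sample mean is consistent with the expected mean.
check1 = statistical_testing.assert_true_mean_equal_by_dkwm(
    samples, low=0., high=1., expected=expected_mean, false_fail_rate=1e-6)

# Experiment-design side: how small a difference in means this sample size can
# detect. In the docstring this is asserted to be below a tolerance; here it is
# just evaluated and printed.
detectable = statistical_testing.min_discrepancy_of_true_means_detectable_by_dkwm(
    num_samples, low=0., high=1., false_fail_rate=1e-6, false_pass_rate=1e-6)

with tf.Session() as sess:
  _, d = sess.run([check1, detectable])
  print('detectable discrepancy with 5000 samples:', d)
```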
diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
index ece03fe..a3d1783 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
@@ -23,7 +23,6 @@
 from tensorflow.contrib.distributions.python.ops import distribution_util
 from tensorflow.contrib.distributions.python.ops.bijectors.affine_linear_operator import AffineLinearOperator
 from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered
-from tensorflow.contrib.linalg.python.ops import linear_operator_addition as linop_add_lib
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
@@ -36,6 +35,7 @@
 from tensorflow.python.ops.distributions import categorical as categorical_lib
 from tensorflow.python.ops.distributions import distribution as distribution_lib
 from tensorflow.python.ops.distributions import normal as normal_lib
+from tensorflow.python.ops.linalg import linear_operator_addition as linop_add_lib
 from tensorflow.python.ops.linalg import linear_operator_diag as linop_diag_lib
 from tensorflow.python.ops.linalg import linear_operator_full_matrix as linop_full_lib
 from tensorflow.python.ops.linalg import linear_operator_identity as linop_identity_lib
@@ -300,7 +300,8 @@
   #### Examples
 
   ```python
-  tfd = tf.contrib.distributions
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
 
   # Create two batches of VectorDiffeomixtures, one with mix_loc=[0.],
   # another with mix_loc=[1]. In both cases, `K=2` and the affine
diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py
index 73356a3..36cbd71 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py
@@ -90,7 +90,8 @@
   #### Examples
 
   ```python
-  tfd = tf.contrib.distributions
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
 
   # Initialize a single 2-variate VectorExponential, supported on
   # {(x, y) in R^2 : x > 0, y > 0}.
diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py
index 9a47b48..fd5bf9e 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py
@@ -108,7 +108,8 @@
   #### Examples
 
   ```python
-  tfd = tf.contrib.distributions
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
 
   # Initialize a single 2-variate VectorExponential, supported on
   # {(x, y) in R^2 : x > 0, y > 0}.
diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py
index e68ddc5..8cd4e12 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py
@@ -102,7 +102,8 @@
   #### Examples
 
   ```python
-  tfd = tf.contrib.distributions
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
 
   # Initialize a single 2-variate VectorLaplace.
   vla = tfd.VectorLaplaceDiag(
diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py
index 3923161..67d2ccd 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py
@@ -110,7 +110,8 @@
   #### Examples
 
   ```python
-  tfd = tf.contrib.distributions
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
 
   # Initialize a single 3-variate VectorLaplace with some desired covariance.
   mu = [1., 2, 3]
diff --git a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py
index 49ffff2..da57d0c 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py
@@ -152,7 +152,7 @@
         broadcastable with `event_shape`.
       distribution: `tf.Distribution`-like instance. Distribution from which `k`
         iid samples are used as input to transformation `F`.  Default is
-        `tf.distributions.Normal(loc=0., scale=1.)`.
+        `tfp.distributions.Normal(loc=0., scale=1.)`.
         Must be a scalar-batch, scalar-event distribution.  Typically
         `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
         a function of non-trainable parameters. WARNING: If you backprop through
diff --git a/tensorflow/contrib/distributions/python/ops/vector_student_t.py b/tensorflow/contrib/distributions/python/ops/vector_student_t.py
index f289b39e..bad91a0 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_student_t.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_student_t.py
@@ -92,7 +92,8 @@
   Extra leading dimensions, if provided, allow for batches.
 
   ```python
-  tfd = tf.contrib.distributions
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
 
   # Initialize a single 3-variate vector Student's t-distribution.
   mu = [1., 2, 3]
diff --git a/tensorflow/contrib/distributions/python/ops/wishart.py b/tensorflow/contrib/distributions/python/ops/wishart.py
index 49b9de0..ee2fc58 100644
--- a/tensorflow/contrib/distributions/python/ops/wishart.py
+++ b/tensorflow/contrib/distributions/python/ops/wishart.py
@@ -480,11 +480,14 @@
   #### Examples
 
   ```python
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
+
   # Initialize a single 3x3 Wishart with Cholesky factored scale matrix and 5
   # degrees-of-freedom.(*)
   df = 5
   chol_scale = tf.cholesky(...)  # Shape is [3, 3].
-  dist = tf.contrib.distributions.WishartCholesky(df=df, scale=chol_scale)
+  dist = tfd.WishartCholesky(df=df, scale=chol_scale)
 
   # Evaluate this on an observation in R^3, returning a scalar.
   x = ...  # A 3x3 positive definite matrix.
@@ -498,14 +501,14 @@
   # Initialize two 3x3 Wisharts with Cholesky factored scale matrices.
   df = [5, 4]
   chol_scale = tf.cholesky(...)  # Shape is [2, 3, 3].
-  dist = tf.contrib.distributions.WishartCholesky(df=df, scale=chol_scale)
+  dist = tfd.WishartCholesky(df=df, scale=chol_scale)
 
   # Evaluate this on four observations.
   x = [[x0, x1], [x2, x3]]  # Shape is [2, 2, 3, 3].
   dist.prob(x)  # Shape is [2, 2].
 
   # (*) - To efficiently create a trainable covariance matrix, see the example
-  #   in tf.contrib.distributions.matrix_diag_transform.
+  #   in tfp.distributions.matrix_diag_transform.
   ```
 
   """
@@ -604,11 +607,14 @@
   #### Examples
 
   ```python
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
+
   # Initialize a single 3x3 Wishart with Full factored scale matrix and 5
   # degrees-of-freedom.(*)
   df = 5
   scale = ...  # Shape is [3, 3]; positive definite.
-  dist = tf.contrib.distributions.WishartFull(df=df, scale=scale)
+  dist = tfd.WishartFull(df=df, scale=scale)
 
   # Evaluate this on an observation in R^3, returning a scalar.
   x = ...  # A 3x3 positive definite matrix.
@@ -622,14 +628,14 @@
   # Initialize two 3x3 Wisharts with Full factored scale matrices.
   df = [5, 4]
   scale = ...  # Shape is [2, 3, 3].
-  dist = tf.contrib.distributions.WishartFull(df=df, scale=scale)
+  dist = tfd.WishartFull(df=df, scale=scale)
 
   # Evaluate this on four observations.
   x = [[x0, x1], [x2, x3]]  # Shape is [2, 2, 3, 3]; xi is positive definite.
   dist.prob(x)  # Shape is [2, 2].
 
   # (*) - To efficiently create a trainable covariance matrix, see the example
-  #   in tf.contrib.distributions.matrix_diag_transform.
+  #   in tfd.matrix_diag_transform.
   ```
 
   """
diff --git a/tensorflow/contrib/eager/README.md b/tensorflow/contrib/eager/README.md
index 86d2034..4bd2769 100644
--- a/tensorflow/contrib/eager/README.md
+++ b/tensorflow/contrib/eager/README.md
@@ -44,7 +44,6 @@
 
 For an introduction to eager execution in TensorFlow, see:
 
-- [User Guide](https://www.tensorflow.org/guide/eager) ([source](../../docs_src/guide/eager.md))
-- Notebook: [Basic Usage](python/examples/notebooks/1_basics.ipynb)
-- Notebook: [Gradients](python/examples/notebooks/2_gradients.ipynb)
-- Notebook: [Importing Data](python/examples/notebooks/3_datasets.ipynb)
+- [User Guide](https://www.tensorflow.org/guide/eager) ([source](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/eager/index.md))
+- Notebook: [Basic Usage](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/eager/eager_basics.ipynb)
+- Notebook: [Automatic differentiation and gradient tape](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/eager/automatic_differentiation.ipynb)
diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
index 7ed77bc..11f60c8 100644
--- a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
+++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
@@ -20,6 +20,7 @@
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.estimator import estimator
 from tensorflow.python.estimator.canned import boosted_trees as canned_boosted_trees
+from tensorflow.python.estimator.canned import head as head_lib
 
 
 def _validate_input_fn_and_repeat_dataset(train_input_fn):
@@ -33,7 +34,18 @@
   return _input_fn
 
 
-class _BoostedTreesEstimator(estimator.Estimator):
+# pylint: disable=protected-access
+def _is_classification_head(head):
+  """Infers if the head is a classification head."""
+  # Check against the classification heads defined in canned/head.py. This is
+  # not an exhaustive check: classification heads defined outside that library
+  # will not be recognized.
+  return isinstance(head,
+                    (head_lib._BinaryLogisticHeadWithSigmoidCrossEntropyLoss,
+                     head_lib._MultiClassHeadWithSoftmaxCrossEntropyLoss))
+
+
+class _BoostedTreesEstimator(canned_boosted_trees._BoostedTreesBase):
   """An Estimator for Tensorflow Boosted Trees models."""
 
   def __init__(self,
@@ -96,8 +108,10 @@
         negative gain). For pre and post pruning, you MUST provide
         tree_complexity >0.
 
+    Raises:
+      ValueError: when invalid arguments are given or unsupported
+        functionality is requested.
     """
-    # pylint:disable=protected-access
     # HParams for the model.
     tree_hparams = canned_boosted_trees._TreeHParams(
         n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
@@ -115,8 +129,14 @@
           config=config)
 
     super(_BoostedTreesEstimator, self).__init__(
-        model_fn=_model_fn, model_dir=model_dir, config=config)
-    # pylint:enable=protected-access
+        model_fn=_model_fn,
+        model_dir=model_dir,
+        config=config,
+        feature_columns=feature_columns,
+        head=head,
+        center_bias=center_bias,
+        is_classification=_is_classification_head(head))
+    # pylint: enable=protected-access
 
 
 def boosted_trees_classifier_train_in_memory(
diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
index b1581f3..e23d9c0 100644
--- a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
@@ -360,5 +360,79 @@
         [pred['predictions'] for pred in predictions])
 
 
+class BoostedTreesDebugOutputTest(test_util.TensorFlowTestCase):
+
+  def setUp(self):
+    self._head = canned_boosted_trees._create_regression_head(label_dimension=1)
+    self._feature_columns = {
+        feature_column.bucketized_column(
+            feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32),
+            BUCKET_BOUNDARIES) for i in range(NUM_FEATURES)
+    }
+
+  def testContribEstimatorThatDFCIsInPredictions(self):
+    # pylint:disable=protected-access
+    head = canned_boosted_trees._create_regression_head(label_dimension=1)
+    train_input_fn = _make_train_input_fn(is_classification=False)
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+    est = boosted_trees._BoostedTreesEstimator(
+        feature_columns=self._feature_columns,
+        n_batches_per_layer=1,
+        head=head,
+        n_trees=1,
+        max_depth=5,
+        center_bias=True)
+    # pylint:enable=protected-access
+
+    num_steps = 100
+    # Train for a few steps. Validate debug outputs in prediction dicts.
+    est.train(train_input_fn, steps=num_steps)
+    debug_predictions = est.experimental_predict_with_explanations(
+        predict_input_fn)
+    biases, dfcs = zip(*[(pred['bias'], pred['dfc'])
+                         for pred in debug_predictions])
+    self.assertAllClose([1.8] * 5, biases)
+    self.assertAllClose(({
+        0: -0.070499420166015625,
+        1: -0.095000028610229492,
+        2: 0.0
+    }, {
+        0: -0.53763031959533691,
+        1: 0.063333392143249512,
+        2: 0.0
+    }, {
+        0: -0.51756942272186279,
+        1: -0.095000028610229492,
+        2: 0.0
+    }, {
+        0: 0.1563495397567749,
+        1: 0.063333392143249512,
+        2: 0.0
+    }, {
+        0: 0.96934974193572998,
+        1: 0.063333392143249512,
+        2: 0.0
+    }), dfcs)
+
+    # Assert sum(dfcs) + bias == predictions.
+    expected_predictions = [[1.6345005], [1.32570302], [1.1874305],
+                            [2.01968288], [2.83268309]]
+    predictions = [
+        [sum(dfc.values()) + bias] for (dfc, bias) in zip(dfcs, biases)
+    ]
+    self.assertAllClose(expected_predictions, predictions)
+
+    # Test when user doesn't include bias or dfc in predict_keys.
+    debug_predictions = est.experimental_predict_with_explanations(
+        predict_input_fn, predict_keys=['predictions'])
+    for prediction_dict in debug_predictions:
+      self.assertTrue('bias' in prediction_dict)
+      self.assertTrue('dfc' in prediction_dict)
+      self.assertTrue('predictions' in prediction_dict)
+      self.assertEqual(len(prediction_dict), 3)
+
+
 if __name__ == '__main__':
   googletest.main()
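
The new test asserts that each prediction decomposes into the per-example bias plus the sum of its directional feature contributions (DFCs). A tiny pure-Python check of that identity using the first example's numbers from the test above:

```python
# Values copied from the first prediction in the test above.
bias = 1.8
dfc = {0: -0.070499420166015625, 1: -0.095000028610229492, 2: 0.0}

reconstructed = bias + sum(dfc.values())
print(reconstructed)  # ~1.6345005, matching expected_predictions[0][0].
```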
diff --git a/tensorflow/contrib/estimator/python/estimator/early_stopping.py b/tensorflow/contrib/estimator/python/estimator/early_stopping.py
index 3eab21d..e6e25e3 100644
--- a/tensorflow/contrib/estimator/python/estimator/early_stopping.py
+++ b/tensorflow/contrib/estimator/python/estimator/early_stopping.py
@@ -18,6 +18,7 @@
 from __future__ import division
 from __future__ import print_function
 
+import collections
 import operator
 import os
 
@@ -306,7 +307,8 @@
         metrics[value.tag] = value.simple_value
     if metrics:
       eval_metrics_dict[event.step] = metrics
-  return eval_metrics_dict
+  return collections.OrderedDict(
+      sorted(eval_metrics_dict.items(), key=lambda t: t[0]))
 
 
 def _stop_if_threshold_crossed_hook(estimator, metric_name, threshold,
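
The early_stopping change above now returns the per-step eval metrics sorted by global step. A minimal sketch of the same idiom on a toy dict (keys and values illustrative):

```python
import collections

# Toy eval metrics keyed by global step, in arbitrary insertion order.
eval_metrics_dict = {300: {'loss': 0.9}, 100: {'loss': 1.4}, 200: {'loss': 1.1}}

ordered = collections.OrderedDict(
    sorted(eval_metrics_dict.items(), key=lambda t: t[0]))

print(list(ordered.keys()))  # [100, 200, 300]
```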
diff --git a/tensorflow/contrib/estimator/python/estimator/hooks.py b/tensorflow/contrib/estimator/python/estimator/hooks.py
index 66c46e6..49f7bbd 100644
--- a/tensorflow/contrib/estimator/python/estimator/hooks.py
+++ b/tensorflow/contrib/estimator/python/estimator/hooks.py
@@ -53,6 +53,7 @@
   ```
 
   Current limitations of this approach are:
+
   * It doesn't support multi-node distributed mode.
   * It doesn't support saveable objects other than variables (such as boosted
     tree support)
diff --git a/tensorflow/contrib/framework/python/framework/tensor_util_test.py b/tensorflow/contrib/framework/python/framework/tensor_util_test.py
index b1820c1..9b0b9b1 100644
--- a/tensorflow/contrib/framework/python/framework/tensor_util_test.py
+++ b/tensorflow/contrib/framework/python/framework/tensor_util_test.py
@@ -186,7 +186,7 @@
           unexpected_shapes)
 
   def test_with_shape_2x2_with_partial_expected_shape(self):
-    with self.test_session():
+    with self.cached_session():
       value = [[42, 43], [44, 45]]
       actual_shape = [2, 2]
       tensor = constant_op.constant(value, shape=actual_shape)
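
Most of the test-only hunks in this change swap `self.test_session()` for `self.cached_session()`. A minimal sketch of the pattern (the test name and tensor are illustrative), assuming `cached_session` is available on `tf.test.TestCase` as of TF 1.11:

```python
import tensorflow as tf


class CachedSessionExampleTest(tf.test.TestCase):

  def testConstant(self):
    # cached_session() reuses one session across the test instead of
    # constructing a new session for every `with` block.
    with self.cached_session() as sess:
      value = tf.constant([[42, 43], [44, 45]])
      self.assertAllEqual([[42, 43], [44, 45]], sess.run(value))


if __name__ == '__main__':
  tf.test.main()
```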
diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
index 716bb87..e9e6464 100644
--- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
+++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
@@ -497,7 +497,8 @@
                                 FORMAT_OIHW, filter_param.shape(), FORMAT_HWIO),
                             &maybe_transformed_filter));
     functor::TransformFilter<GPUDevice, T, int, 4>()(
-        ctx->eigen_device<GPUDevice>(), To32Bit(filter_param.tensor<T, 4>()),
+        ctx->eigen_device<GPUDevice>(), FORMAT_OIHW,
+        To32Bit(filter_param.tensor<T, 4>()),
         To32Bit(maybe_transformed_filter.tensor<T, 4>()));
     filter = &maybe_transformed_filter;
   }
diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py
index 0185ef6..e47342b 100644
--- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py
+++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py
@@ -265,7 +265,7 @@
     tensors = []
     for (data_format, use_gpu) in GetTestConfigs():
       tensors.append(_SetupVal(data_format, use_gpu))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       values = sess.run(tensors)
       for i in range(1, len(values)):
         self.assertAllClose(values[0], values[i], rtol=1e-5, atol=1e-5)
@@ -282,7 +282,7 @@
               data_format, filter_format, dtype)
         tensors.append(result)
         ref_tensors.append(expected)
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         values = sess.run(tensors)
         ref_values = sess.run(ref_tensors)
         for i in range(len(tensors)):
@@ -493,7 +493,7 @@
     if gpu_only and not test.is_gpu_available():
       tf_logging.info("Skipping OpEdgeCases tests.")
       return
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Illegal strides.
       with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                    "Convolutional strides are not supported in "
diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl.py b/tensorflow/contrib/gan/python/losses/python/losses_impl.py
index d389748..8bc4db8 100644
--- a/tensorflow/contrib/gan/python/losses/python/losses_impl.py
+++ b/tensorflow/contrib/gan/python/losses/python/losses_impl.py
@@ -773,9 +773,9 @@
     structured_generator_inputs: A list of Tensors representing the random noise
       that must  have high mutual information with the generator output. List
       length should match `predicted_distributions`.
-    predicted_distributions: A list of tf.Distributions. Predicted by the
-      recognizer, and used to evaluate the likelihood of the structured noise.
-      List length should match `structured_generator_inputs`.
+    predicted_distributions: A list of `tfp.distributions.Distribution`s.
+      Predicted by the recognizer, and used to evaluate the likelihood of the
+      structured noise. List length should match `structured_generator_inputs`.
     weights: Optional `Tensor` whose rank is either 0, or the same dimensions as
       `structured_generator_inputs`.
     scope: The scope for the operations performed in computing the loss.
diff --git a/tensorflow/contrib/gan/python/namedtuples.py b/tensorflow/contrib/gan/python/namedtuples.py
index a462b68..b9ac1bf 100644
--- a/tensorflow/contrib/gan/python/namedtuples.py
+++ b/tensorflow/contrib/gan/python/namedtuples.py
@@ -91,9 +91,9 @@
     structured_generator_inputs: A list of Tensors representing the random noise
       that must  have high mutual information with the generator output. List
       length should match `predicted_distributions`.
-    predicted_distributions: A list of tf.Distributions. Predicted by the
-      recognizer, and used to evaluate the likelihood of the structured noise.
-      List length should match `structured_generator_inputs`.
+    predicted_distributions: A list of `tfp.distributions.Distribution`s.
+      Predicted by the recognizer, and used to evaluate the likelihood of the
+      structured noise. List length should match `structured_generator_inputs`.
     discriminator_and_aux_fn: The original discriminator function that returns
       a tuple of (logits, `predicted_distributions`).
   """
diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py
index 58f3480..64d6706 100644
--- a/tensorflow/contrib/gan/python/train_test.py
+++ b/tensorflow/contrib/gan/python/train_test.py
@@ -399,7 +399,7 @@
     target_tensor = train._generate_stargan_random_domain_target(
         batch_size, domain_numbers)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       targets = sess.run(target_tensor)
       self.assertTupleEqual((batch_size, domain_numbers), targets.shape)
       for target in targets:
@@ -676,7 +676,7 @@
 
     self.assertIsInstance(model_loss, namedtuples.GANLoss)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
 
       sess.run(variables.global_variables_initializer())
 
diff --git a/tensorflow/contrib/gdr/gdr_memory_manager.cc b/tensorflow/contrib/gdr/gdr_memory_manager.cc
index 726f74c..bb06f1c 100644
--- a/tensorflow/contrib/gdr/gdr_memory_manager.cc
+++ b/tensorflow/contrib/gdr/gdr_memory_manager.cc
@@ -138,6 +138,8 @@
       Device* device, DeviceContext* device_context, bool on_host,
       StatusCallback done) override;
 
+  static void RegMemVisitors();
+
  protected:
   Status CreateEndpoint(const string& host, const string& port,
                         RdmaEndpointPtr& endpoint);
@@ -183,35 +185,51 @@
   TF_DISALLOW_COPY_AND_ASSIGN(GdrMemoryManager);
 };
 
-// TODO(byronyi): remove this class and its registration when the default
-// cpu_allocator() returns visitable allocator, or cpu_allocator() is no
-// longer in use.
-class BFCGdrAllocator : public BFCAllocator {
- public:
-  BFCGdrAllocator()
-      : BFCAllocator(new BasicCPUAllocator(port::kNUMANoAffinity), 1LL << 36,
-                     true, "cpu_gdr_bfc") {}
-};
-class BFCGdrAllocatorFactory : public AllocatorFactory {
- public:
-  Allocator* CreateAllocator() override { return new BFCGdrAllocator; }
-
-  virtual SubAllocator* CreateSubAllocator(int numa_node) {
-    return new BasicCPUAllocator(numa_node);
-  }
-};
-
-REGISTER_MEM_ALLOCATOR("BFCGdrAllocator", 102, BFCGdrAllocatorFactory);
-
 GdrMemoryManager::GdrMemoryManager(const string& host, const string& port)
     : host_(host),
       port_(port),
       listening_(nullptr, EndpointDeleter),
       stopped_(true),
-      next_key_(0) {}
+      next_key_(0) {
+  static std::once_flag flag;
+  std::call_once(flag, []() { RegMemVisitors(); });
+}
 
 GdrMemoryManager::~GdrMemoryManager() { close(epfd_); }
 
+/*static*/ void GdrMemoryManager::RegMemVisitors() {
+  SubAllocator::Visitor alloc_visitor = [](void* ptr, int numa_node,
+                                           size_t num_bytes) {
+    GdrMemoryManager::Singleton().InsertMemoryRegion(
+        ptr, num_bytes, strings::StrCat("CPU:", numa_node));
+  };
+  SubAllocator::Visitor free_visitor = [](void* ptr, int numa_node,
+                                          size_t num_bytes) {
+    GdrMemoryManager::Singleton().EvictMemoryRegion(ptr, num_bytes);
+  };
+  ProcessState::singleton()->AddCPUAllocVisitor(alloc_visitor);
+  ProcessState::singleton()->AddCPUFreeVisitor(free_visitor);
+
+#if GOOGLE_CUDA
+  if (IsGDRAvailable()) {
+    int32_t bus_id = TryToReadNumaNode(rdma_adapter_->context_->device) + 1;
+
+    // Note we don't free allocated GPU memory so there is no free visitor
+    SubAllocator::Visitor cuda_alloc_visitor = [](void* ptr, int gpu_id,
+                                                  size_t num_bytes) {
+      RdmaMemoryMgr::Singleton().InsertMemoryRegion(
+          ptr, num_bytes, strings::StrCat("GPU:", gpu_id));
+    };
+    GPUProcessState::singleton()->AddGPUAllocVisitor(bus_id,
+                                                     cuda_alloc_visitor);
+    GPUProcessState::singleton()->AddCUDAHostAllocVisitor(bus_id,
+                                                          alloc_visitor);
+    GPUProcessState::singleton()->AddCUDAHostFreeVisitor(bus_id, free_visitor);
+    LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id;
+  }
+#endif  // GOOGLE_CUDA
+}
+
 Status GdrMemoryManager::Init() {
   epfd_ = epoll_create1(0);
   if (epfd_ == -1) {
@@ -271,48 +289,6 @@
                                "cannot add server to epoll");
   }
 
-  Allocator* allocators[] = {
-#if GOOGLE_CUDA
-    GPUProcessState::singleton()->GetCUDAHostAllocator(0),
-#endif  // GOOGLE_CUDA
-    ProcessState::singleton()->GetCPUAllocator(0),
-    cpu_allocator(),
-  };
-
-  using namespace std::placeholders;
-  VisitableAllocator::Visitor alloc_visitor =
-      std::bind(&GdrMemoryManager::InsertMemoryRegion, this, _1, _2);
-  VisitableAllocator::Visitor free_visitor =
-      std::bind(&GdrMemoryManager::EvictMemoryRegion, this, _1, _2);
-
-  std::set<Allocator*> instrumented_;
-
-  // Host memory allocators
-  for (Allocator* allocator : allocators) {
-    auto* visitable_allocator = dynamic_cast<VisitableAllocator*>(allocator);
-    CHECK(visitable_allocator)
-        << "is not visitable for instrumentation" << allocator->Name();
-    // Make sure we don't instrument the same allocator twice
-    if (instrumented_.find(allocator) == std::end(instrumented_)) {
-      visitable_allocator->AddAllocVisitor(alloc_visitor);
-      visitable_allocator->AddFreeVisitor(free_visitor);
-      instrumented_.insert(allocator);
-      LOG(INFO) << "Instrumenting CPU allocator " << allocator->Name();
-    }
-  }
-
-#if GOOGLE_CUDA
-  VisitableAllocator::Visitor cuda_alloc_visitor =
-      std::bind(&GdrMemoryManager::InsertMemoryRegion, this, _1, _2);
-  if (IsGDRAvailable()) {
-    // Note we don't free allocated GPU memory so there is no free visitor
-    int32_t bus_id = TryToReadNumaNode(listening_->verbs->device) + 1;
-    GPUProcessState::singleton()->AddGPUAllocVisitor(bus_id,
-                                                     cuda_alloc_visitor);
-    LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id;
-  }
-#endif  // GOOGLE_CUDA
-
   return Status::OK();
 }
 
diff --git a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
index fed8a77..27aed09 100644
--- a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
+++ b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
@@ -233,7 +233,7 @@
                     ([[1.38917875, 1.49043763]], [[0.83884692, 0.86036491]])))
 
   def testGrid2LSTMCellWithRelu(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           'root', initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 3])
@@ -261,7 +261,7 @@
   """
 
   def testGrid2BasicRNNCell(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           'root', initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([2, 2])
@@ -292,7 +292,7 @@
                     [[0.80049908, 0.80049908], [0.97574311, 0.97574311]]))
 
   def testGrid2BasicRNNCellTied(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           'root', initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([2, 2])
@@ -323,7 +323,7 @@
                     [[0.80049908, 0.80049908], [0.97574311, 0.97574311]]))
 
   def testGrid2BasicRNNCellWithRelu(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           'root', initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 2])
@@ -348,7 +348,7 @@
   """
 
   def testGrid1LSTMCell(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           'root', initializer=init_ops.constant_initializer(0.5)) as root_scope:
         x = array_ops.zeros([1, 3])
@@ -410,7 +410,7 @@
   """
 
   def testGrid3LSTMCell(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           'root', initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 3])
@@ -455,7 +455,7 @@
   """
 
   def testGridRNNEdgeCasesLikeRelu(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           'root', initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([3, 2])
@@ -481,7 +481,7 @@
         self.assertAllClose(res_g, ([[0, 0], [0, 0], [0.5, 0.5]],))
 
   def testGridRNNEdgeCasesNoOutput(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           'root', initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 2])
@@ -541,7 +541,7 @@
       self.assertEqual(out[0].get_shape()[1], num_units)
       self.assertEqual(out[0].dtype, inp.dtype)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.global_variables_initializer())
 
       input_value = np.ones((batch_size, input_size))
@@ -581,7 +581,7 @@
       self.assertEqual(out[0].get_shape()[1], num_units)
       self.assertEqual(out[0].dtype, inp.dtype)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.global_variables_initializer())
 
       input_value = np.ones((batch_size, input_size))
@@ -623,7 +623,7 @@
       self.assertEqual(out[0].get_shape()[1], num_units)
       self.assertEqual(out[0].dtype, inp.dtype)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.global_variables_initializer())
 
       input_value = np.ones((batch_size, input_size))
@@ -663,7 +663,7 @@
       self.assertEqual(out[0].get_shape(), (3, num_units))
       self.assertEqual(out[0].dtype, inp.dtype)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.global_variables_initializer())
 
       input_value = np.ones((batch_size, input_size))
@@ -700,7 +700,7 @@
       self.assertEqual(out[0].get_shape()[1], num_units)
       self.assertEqual(out[0].dtype, inp.dtype)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.global_variables_initializer())
 
       input_value = np.ones((3, input_size))
@@ -715,7 +715,7 @@
 
   def testGrid2LSTMCellLegacy(self):
     """Test for legacy case (when state_is_tuple=False)."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           'root', initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 3])
diff --git a/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py b/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py
index d796e43..f7f1189 100644
--- a/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py
+++ b/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py
@@ -51,7 +51,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for _ in range(num_repeats):  # Dataset is repeated.
         for i in range(25):  # 25 records.
diff --git a/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py b/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py
index 9ed0175..f44edaa 100644
--- a/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py
+++ b/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py
@@ -29,7 +29,7 @@
 class InputPipelineOpsTest(test.TestCase):
 
   def testObtainNext(self):
-    with self.test_session():
+    with self.cached_session():
       var = state_ops.variable_op([], dtypes.int64)
       state_ops.assign(var, -1).op.run()
       c = constant_op.constant(["a", "b"])
@@ -45,7 +45,7 @@
 
   def testSeekNext(self):
     string_list = ["a", "b", "c"]
-    with self.test_session() as session:
+    with self.cached_session() as session:
       elem = input_pipeline_ops.seek_next(string_list)
       session.run([variables.global_variables_initializer()])
       self.assertEqual(b"a", session.run(elem))
@@ -65,7 +65,7 @@
 
   def testSeekNextLimitEpochs(self):
     string_list = ["a", "b", "c"]
-    with self.test_session() as session:
+    with self.cached_session() as session:
       elem = input_pipeline_ops.seek_next(string_list, num_epochs=1)
       session.run([
           variables.local_variables_initializer(),
@@ -75,7 +75,7 @@
 
   def testSeekNextLimitEpochsThree(self):
     string_list = ["a", "b", "c"]
-    with self.test_session() as session:
+    with self.cached_session() as session:
       elem = input_pipeline_ops.seek_next(string_list, num_epochs=3)
       session.run([
           variables.local_variables_initializer(),
diff --git a/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py
index 6219118..08ebcdb 100644
--- a/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py
+++ b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py
@@ -54,7 +54,7 @@
     init_batch_op = iterator.make_initializer(batch_dataset)
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Basic test: read from topic 0.
       sess.run(init_op, feed_dict={topics: ["test:0:0:4"], num_epochs: 1})
       for i in range(5):
diff --git a/tensorflow/contrib/kernel_methods/python/losses_test.py b/tensorflow/contrib/kernel_methods/python/losses_test.py
index 7250753..4d5cc24 100644
--- a/tensorflow/contrib/kernel_methods/python/losses_test.py
+++ b/tensorflow/contrib/kernel_methods/python/losses_test.py
@@ -32,7 +32,7 @@
 
   def testInvalidLogitsShape(self):
     """An error is raised when logits have invalid shape."""
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([-1.0, 2.1], shape=(2,))
       labels = constant_op.constant([0, 1])
       with self.assertRaises(ValueError):
@@ -40,7 +40,7 @@
 
   def testInvalidLabelsShape(self):
     """An error is raised when labels have invalid shape."""
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
       labels = constant_op.constant([1, 0], shape=(1, 1, 2))
       with self.assertRaises(ValueError):
@@ -48,7 +48,7 @@
 
   def testInvalidWeightsShape(self):
     """An error is raised when weights have invalid shape."""
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
       labels = constant_op.constant([1, 0], shape=(2,))
       weights = constant_op.constant([1.5, 0.2], shape=(2, 1, 1))
@@ -57,7 +57,7 @@
 
   def testInvalidLabelsDtype(self):
     """An error is raised when labels have invalid shape."""
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
       labels = constant_op.constant([1, 0], dtype=dtypes.float32)
       with self.assertRaises(ValueError):
@@ -65,7 +65,7 @@
 
   def testNoneWeightRaisesValueError(self):
     """An error is raised when weights are None."""
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
       labels = constant_op.constant([1, 0])
       with self.assertRaises(ValueError):
@@ -73,7 +73,7 @@
 
   def testInconsistentLabelsAndWeightsShapesSameRank(self):
     """Error raised when weights and labels have same ranks, different sizes."""
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([-1.0, 2.1, 4.1], shape=(3, 1))
       labels = constant_op.constant([1, 0, 2], shape=(3, 1))
       weights = constant_op.constant([1.1, 2.0], shape=(2, 1))
@@ -82,7 +82,7 @@
 
   def testInconsistentLabelsAndWeightsShapesDifferentRank(self):
     """Error raised when weights and labels have different ranks and sizes."""
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
       labels = constant_op.constant([1, 0], shape=(2, 1))
       weights = constant_op.constant([1.1, 2.0, 2.8], shape=(3,))
@@ -91,7 +91,7 @@
 
   def testOutOfRangeLabels(self):
     """An error is raised when labels are not in [0, num_classes)."""
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[1.2, -1.4, -1.0], [1.4, 1.8, 4.0],
                                      [0.5, 1.8, -1.0]])
       labels = constant_op.constant([1, 0, 4])
@@ -101,7 +101,7 @@
 
   def testZeroLossInt32Labels(self):
     """Loss is 0 if true class logits sufficiently higher than other classes."""
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[1.2, -1.4, -1.0], [1.4, 1.8, 4.0],
                                      [0.5, 1.8, -1.0]])
       labels = constant_op.constant([0, 2, 1], dtype=dtypes.int32)
@@ -110,7 +110,7 @@
 
   def testZeroLossInt64Labels(self):
     """Loss is 0 if true class logits sufficiently higher than other classes."""
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[2.1, -0.4, -1.0], [1.4, 2.8, 4.0],
                                      [-0.5, 0.8, -1.0]])
       labels = constant_op.constant([0, 2, 1], dtype=dtypes.int64)
@@ -130,7 +130,7 @@
     ]
 
     for batch_size, num_classes in logits_shapes:
-      with self.test_session():
+      with self.cached_session():
         logits = array_ops.placeholder(
             dtypes.float32, shape=(batch_size, num_classes))
         labels = array_ops.placeholder(dtypes.int32, shape=(batch_size,))
@@ -140,7 +140,7 @@
 
   def testCorrectPredictionsSomeClassesInsideMargin(self):
     """Loss is > 0 even if true class logits are higher than other classes."""
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[1.2, -1.4, 0.8], [1.4, 1.8, 4.0],
                                      [1.5, 1.8, -1.0]])
       labels = constant_op.constant([0, 2, 1])
@@ -150,7 +150,7 @@
 
   def testIncorrectPredictions(self):
     """Loss is >0 when an incorrect class has higher logits than true class."""
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[2.6, 0.4, 0.8], [1.4, 0.8, -1.0],
                                      [0.5, -1.8, 2.0]])
       labels = constant_op.constant([1, 0, 2])
@@ -162,7 +162,7 @@
 
   def testIncorrectPredictionsColumnLabels(self):
     """Same as above but labels is a rank-2 tensor."""
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
                                      [0.2, -1.8, 4.0]])
       labels = constant_op.constant([1, 0, 2], shape=(3, 1))
@@ -174,7 +174,7 @@
 
   def testIncorrectPredictionsZeroWeights(self):
     """Loss is 0 when all weights are missing even if predictions are wrong."""
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
                                      [0.2, -1.8, 4.0]])
       labels = constant_op.constant([1, 0, 2], shape=(3, 1))
@@ -185,7 +185,7 @@
 
   def testNonZeroLossWithPythonScalarWeights(self):
     """Weighted loss is correctly computed when weights is a python scalar."""
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
                                      [0.2, -1.8, 4.0]])
       labels = constant_op.constant([1, 0, 2], shape=(3, 1))
@@ -195,7 +195,7 @@
 
   def testNonZeroLossWithScalarTensorWeights(self):
     """Weighted loss is correctly computed when weights is a rank-0 tensor."""
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
                                      [0.2, -1.8, 4.0]])
       labels = constant_op.constant([1, 0, 2], shape=(3, 1))
@@ -205,7 +205,7 @@
 
   def testNonZeroLossWith1DTensorWeightsColumnLabels(self):
     """Weighted loss is correctly computed when weights is a rank-0 tensor."""
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
                                      [0.2, -1.8, 4.0]])
       labels = constant_op.constant([1, 0, 2], shape=(3, 1))
@@ -216,7 +216,7 @@
 
   def testNonZeroLossWith2DTensorWeights1DLabelsSomeWeightsMissing(self):
     """Weighted loss is correctly computed when weights is a rank-0 tensor."""
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
                                      [0.2, -1.8, 4.0], [1.6, 1.8, -4.0]])
       labels = constant_op.constant([1, 0, 2, 1])
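
The docstrings above describe the intended behaviour ("loss is 0 if the true class logit is sufficiently higher than the others", "> 0 when an incorrect class has higher logits"). A small numpy sketch of the standard multiclass hinge rule those statements describe; whether `sparse_multiclass_hinge_loss` uses exactly this Crammer-Singer form is an assumption:

```python
import numpy as np


def multiclass_hinge(logits, labels, margin=1.0):
  """Per-example hinge loss: max(0, margin - (true logit - best other logit))."""
  losses = []
  for logit_row, label in zip(logits, labels):
    true_logit = logit_row[label]
    other_best = max(l for i, l in enumerate(logit_row) if i != label)
    losses.append(max(0.0, margin - (true_logit - other_best)))
  return np.array(losses)


# Logits and labels from testZeroLossInt32Labels above: every true-class logit
# beats the runner-up by more than the margin, so every loss is 0.
logits = [[1.2, -1.4, -1.0], [1.4, 1.8, 4.0], [0.5, 1.8, -1.0]]
labels = [0, 2, 1]
print(multiclass_hinge(logits, labels))  # [0. 0. 0.]
```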
diff --git a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py
index 2ff4d41..bad0a59 100644
--- a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py
+++ b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py
@@ -58,7 +58,7 @@
   def testInvalidInputShape(self):
     x = constant_op.constant([[2.0, 1.0]])
 
-    with self.test_session():
+    with self.cached_session():
       rffm = RandomFourierFeatureMapper(3, 10)
       with self.assertRaisesWithPredicateMatch(
           dense_kernel_mapper.InvalidShapeError,
@@ -70,7 +70,7 @@
     x2 = constant_op.constant([[1.0, -1.0, 2.0], [-1.0, 10.0, 1.0],
                                [4.0, -2.0, -1.0]])
 
-    with self.test_session():
+    with self.cached_session():
       rffm = RandomFourierFeatureMapper(3, 10, 1.0)
       mapped_x1 = rffm.map(x1)
       mapped_x2 = rffm.map(x2)
@@ -80,7 +80,7 @@
   def testSameOmegaReused(self):
     x = constant_op.constant([[2.0, 1.0, 0.0]])
 
-    with self.test_session():
+    with self.cached_session():
       rffm = RandomFourierFeatureMapper(3, 100)
       mapped_x = rffm.map(x)
       mapped_x_copy = rffm.map(x)
@@ -93,7 +93,7 @@
     y = constant_op.constant([[1.0, -1.0, 2.0]])
     stddev = 3.0
 
-    with self.test_session():
+    with self.cached_session():
       # The mapped dimension is fairly small, so the kernel approximation is
       # very rough.
       rffm1 = RandomFourierFeatureMapper(3, 100, stddev)
@@ -113,7 +113,7 @@
     y = constant_op.constant([[1.0, -1.0, 2.0]])
     stddev = 3.0
 
-    with self.test_session():
+    with self.cached_session():
       # The mapped dimension is fairly small, so the kernel approximation is
       # very rough.
       rffm = RandomFourierFeatureMapper(3, 100, stddev, seed=0)
@@ -139,7 +139,7 @@
 
     normalized_points = [nn.l2_normalize(point, dim=1) for point in points]
     total_absolute_error = 0.0
-    with self.test_session():
+    with self.cached_session():
       rffm = RandomFourierFeatureMapper(input_dim, mapped_dim, stddev, seed=0)
       # Cache mappings so that they are not computed multiple times.
       cached_mappings = dict((point, rffm.map(point))
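
The tests above check that inner products of mapped vectors approximate an RBF kernel. A small numpy sketch of the random Fourier feature idea (this is the textbook construction with Gaussian frequencies and cosine features; it is an assumption that the mapper uses exactly this form):

```python
import numpy as np


def rbf_kernel(x, y, stddev):
  return np.exp(-np.sum((x - y) ** 2) / (2.0 * stddev ** 2))


def random_fourier_map(x, omega, b, mapped_dim):
  # phi(x) = sqrt(2/D) * cos(omega^T x + b); E[phi(x) . phi(y)] = k(x, y).
  return np.sqrt(2.0 / mapped_dim) * np.cos(x @ omega + b)


rng = np.random.RandomState(0)
input_dim, mapped_dim, stddev = 3, 5000, 3.0
omega = rng.normal(scale=1.0 / stddev, size=(input_dim, mapped_dim))
b = rng.uniform(0.0, 2.0 * np.pi, size=mapped_dim)

x = np.array([1.0, -1.0, 2.0])
y = np.array([2.0, 1.0, 0.0])
approx = random_fourier_map(x, omega, b, mapped_dim) @ random_fourier_map(y, omega, b, mapped_dim)
print(rbf_kernel(x, y, stddev), approx)  # Close for large mapped_dim.
```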
diff --git a/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py b/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py
index 7289b45..bf89922 100644
--- a/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py
+++ b/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py
@@ -64,7 +64,7 @@
     init_batch_op = iterator.make_initializer(batch_dataset)
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Basic test: read from shard 0 of stream 1.
       sess.run(init_op, feed_dict={stream: stream_name, num_epochs: 1})
       for i in range(10):
@@ -108,7 +108,7 @@
     get_next = iterator.get_next()
 
     data = list()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Basic test: read from shard 0 of stream 2.
       sess.run(
           init_op, feed_dict={
diff --git a/tensorflow/contrib/layers/python/kernel_tests/sparse_feature_cross_op_test.py b/tensorflow/contrib/layers/python/kernel_tests/sparse_feature_cross_op_test.py
index 28ddaa6..155d06a 100644
--- a/tensorflow/contrib/layers/python/kernel_tests/sparse_feature_cross_op_test.py
+++ b/tensorflow/contrib/layers/python/kernel_tests/sparse_feature_cross_op_test.py
@@ -45,7 +45,7 @@
         'batch2-FC1-F1_X_batch2-FC2-F1', 'batch2-FC1-F1_X_batch2-FC2-F2',
         'batch2-FC1-F2_X_batch2-FC2-F1', 'batch2-FC1-F2_X_batch2-FC2-F2'
     ]])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._assert_sparse_tensor_equals(expected_out, sess.run(op))
 
   def test_dense(self):
@@ -66,7 +66,7 @@
         'batch2-FC1-F1_X_batch2-FC2-F1', 'batch2-FC1-F1_X_batch2-FC2-F2',
         'batch2-FC1-F2_X_batch2-FC2-F1', 'batch2-FC1-F2_X_batch2-FC2-F2'
     ]])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._assert_sparse_tensor_equals(expected_out, sess.run(op))
 
   def test_integer_mixed_string_sparse(self):
@@ -80,7 +80,7 @@
         '333_X_batch2-FC2-F1', '333_X_batch2-FC2-F2', '55555_X_batch2-FC2-F1',
         '55555_X_batch2-FC2-F2'
     ]])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._assert_sparse_tensor_equals(expected_out, sess.run(op))
 
   def test_integer_mixed_string_dense(self):
@@ -99,7 +99,7 @@
         '55555_X_batch2-FC2-F1', '55555_X_batch2-FC2-F2',
         '999999_X_batch2-FC2-F1', '999999_X_batch2-FC2-F2'
     ]])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._assert_sparse_tensor_equals(expected_out, sess.run(op))
 
   def test_sparse_cross_dense(self):
@@ -117,7 +117,7 @@
             'batch2-FC1-F1_X_batch2-FC2-F1', 'batch2-FC1-F1_X_batch2-FC2-F2',
             'batch2-FC1-F2_X_batch2-FC2-F1', 'batch2-FC1-F2_X_batch2-FC2-F2'
         ]])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._assert_sparse_tensor_equals(expected_out, sess.run(op))
 
   def test_integer_sparse_input(self):
@@ -133,7 +133,7 @@
             '333_X_batch2-FC2-F1', '333_X_batch2-FC2-F2',
             '5555_X_batch2-FC2-F1', '5555_X_batch2-FC2-F2'
         ]])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._assert_sparse_tensor_equals(expected_out, sess.run(op))
 
   def test_permutation_3x3x3(self):
@@ -176,7 +176,7 @@
         'batch1-FC1-F3_X_batch1-FC2-F3_X_batch1-FC3-F2',
         'batch1-FC1-F3_X_batch1-FC2-F3_X_batch1-FC3-F3'
     ]])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._assert_sparse_tensor_equals(expected_out, sess.run(op))
 
   def test_permutation_3x1x2(self):
@@ -196,7 +196,7 @@
         'batch1-FC1-F3_X_batch1-FC2-F1_X_batch1-FC3-F1',
         'batch1-FC1-F3_X_batch1-FC2-F1_X_batch1-FC3-F2'
     ]])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._assert_sparse_tensor_equals(expected_out, sess.run(op))
 
   def test_large_batch(self):
@@ -229,7 +229,7 @@
       ])
 
     expected_out = self._sparse_tensor(col_out)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._assert_sparse_tensor_equals(expected_out, sess.run(op))
 
   def test_one_column_empty(self):
@@ -242,7 +242,7 @@
         self._sparse_tensor([], 1),
         self._sparse_tensor([['batch1-FC3-F1', 'batch1-FC3-F2']])
     ])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._assert_sparse_tensor_empty(sess.run(op))
 
   def test_some_columns_empty(self):
@@ -261,7 +261,7 @@
         'batch1-FC1-F2_X_batch1-FC2-F1_X_batch1-FC3-F1',
         'batch1-FC1-F2_X_batch1-FC2-F1_X_batch1-FC3-F2'
     ]], 2)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._assert_sparse_tensor_equals(expected_out, sess.run(op))
 
   def test_all_columns_empty(self):
@@ -273,7 +273,7 @@
         self._sparse_tensor([]), self._sparse_tensor([]),
         self._sparse_tensor([])
     ])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._assert_sparse_tensor_empty(sess.run(op))
 
   def test_hashed_output_zero_bucket(self):
@@ -288,7 +288,7 @@
         hashed_output=True)
     # Check actual hashed output to prevent unintentional hashing changes.
     expected_out = self._sparse_tensor([[3735511728867393167]])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._assert_sparse_tensor_equals(expected_out, sess.run(op))
 
   def test_hashed_output_zero_bucket_v2(self):
@@ -304,7 +304,7 @@
         hash_key=layers.SPARSE_FEATURE_CROSS_DEFAULT_HASH_KEY)
     # Check actual hashed output to prevent unintentional hashing changes.
     expected_out = self._sparse_tensor([[1971693436396284976]])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._assert_sparse_tensor_equals(expected_out, sess.run(op))
 
   # TODO(sibyl-Aix6ihai): Add benchmark to compare Hashed vs Non-hashed.
@@ -321,7 +321,7 @@
         num_buckets=100)
     # Check actual hashed output to prevent unintentional hashing changes.
     expected_out = self._sparse_tensor([[74]])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._assert_sparse_tensor_equals(expected_out, sess.run(op))
 
   def test_hashed_output_v2(self):
@@ -338,7 +338,7 @@
         hash_key=layers.SPARSE_FEATURE_CROSS_DEFAULT_HASH_KEY)
     # Check actual hashed output to prevent unintentional hashing changes.
     expected_out = self._sparse_tensor([[83]])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._assert_sparse_tensor_equals(expected_out, sess.run(op))
 
   def test_hashed_output_v1_has_collision(self):
@@ -384,7 +384,7 @@
         ],
         hashed_output=True,
         num_buckets=1000)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       out = sess.run(op)
       self.assertEqual(6, len(out.values))
       self.assertAllEqual([[0, i] for i in range(6)], out.indices)
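
The expected outputs in these tests follow a simple naming scheme: the cross of per-column feature values joined with `_X_`. A tiny pure-Python illustration of that scheme (itertools only, not the TF op; the feature values are trimmed from the test data above):

```python
import itertools

fc1 = ['batch1-FC1-F1', 'batch1-FC1-F2']
fc2 = ['batch1-FC2-F1']
fc3 = ['batch1-FC3-F1', 'batch1-FC3-F2']

crossed = ['_X_'.join(combo) for combo in itertools.product(fc1, fc2, fc3)]
print(crossed)
# ['batch1-FC1-F1_X_batch1-FC2-F1_X_batch1-FC3-F1', ...]
```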
diff --git a/tensorflow/contrib/learn/python/learn/graph_actions_test.py b/tensorflow/contrib/learn/python/learn/graph_actions_test.py
index d5c0212..33180b7 100644
--- a/tensorflow/contrib/learn/python/learn/graph_actions_test.py
+++ b/tensorflow/contrib/learn/python/learn/graph_actions_test.py
@@ -234,7 +234,7 @@
         self.assertTrue(test_ops.resource_initialized_op(handle).eval())
 
   def test_infer_different_default_graph(self):
-    with self.test_session():
+    with self.cached_session():
       self._assert_ckpt(self._output_dir, False)
       with ops.Graph().as_default():
         in0, in1, out = self._build_inference_graph()
diff --git a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
index 2f33a2b..0e5ea6b 100644
--- a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
+++ b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
@@ -47,7 +47,7 @@
 class Seq2SeqTest(test.TestCase):
 
   def testRNNDecoder(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
@@ -65,7 +65,7 @@
         self.assertEqual((2, 2), res[0].shape)
 
   def testBasicRNNSeq2Seq(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
@@ -81,7 +81,7 @@
         self.assertEqual((2, 2), res[0].shape)
 
   def testTiedRNNSeq2Seq(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
@@ -98,7 +98,7 @@
         self.assertEqual((2, 2), res[0].shape)
 
   def testEmbeddingRNNDecoder(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
@@ -124,7 +124,7 @@
         self.assertEqual((2, 2), res[0].h.shape)
 
   def testEmbeddingRNNSeq2Seq(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         enc_inp = [
@@ -228,7 +228,7 @@
         self.assertAllClose(res1, res3)
 
   def testEmbeddingTiedRNNSeq2Seq(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         enc_inp = [
@@ -316,7 +316,7 @@
         self.assertAllClose(res1, res3)
 
   def testAttentionDecoder1(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         cell_fn = lambda: rnn_cell.GRUCell(2)
@@ -341,7 +341,7 @@
         self.assertEqual((2, 2), res[0].shape)
 
   def testAttentionDecoder2(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         cell_fn = lambda: rnn_cell.GRUCell(2)
@@ -367,7 +367,7 @@
         self.assertEqual((2, 2), res[0].shape)
 
   def testDynamicAttentionDecoder1(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         cell_fn = lambda: rnn_cell.GRUCell(2)
@@ -391,7 +391,7 @@
         self.assertEqual((2, 2), res[0].shape)
 
   def testDynamicAttentionDecoder2(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         cell_fn = lambda: rnn_cell.GRUCell(2)
@@ -416,7 +416,7 @@
         self.assertEqual((2, 2), res[0].shape)
 
   def testAttentionDecoderStateIsTuple(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         single_cell = lambda: rnn_cell.BasicLSTMCell(  # pylint: disable=g-long-lambda
@@ -448,7 +448,7 @@
         self.assertEqual((2, 2), res[0][1].h.shape)
 
   def testDynamicAttentionDecoderStateIsTuple(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         cell_fn = lambda: rnn_cell.MultiRNNCell(  # pylint: disable=g-long-lambda
@@ -479,7 +479,7 @@
         self.assertEqual((2, 2), res[0][1].h.shape)
 
   def testEmbeddingAttentionDecoder(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
@@ -513,7 +513,7 @@
         self.assertEqual((2, 2), res[0].shape)
 
   def testEmbeddingAttentionSeq2Seq(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         enc_inp = [
@@ -622,7 +622,7 @@
         # self.assertAllClose(res1, res3)
 
   def testOne2ManyRNNSeq2Seq(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         enc_inp = [
@@ -712,7 +712,7 @@
         self.assertAllClose(res1, res3)
 
   def testSequenceLoss(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       logits = [constant_op.constant(i + 0.5, shape=[2, 5]) for i in range(3)]
       targets = [
           constant_op.constant(
@@ -748,7 +748,7 @@
       self.assertAllClose(9.656628, res)
 
   def testSequenceLossByExample(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       output_classes = 5
       logits = [
           constant_op.constant(
@@ -778,7 +778,7 @@
   #   classes = 10
   #   buckets = [(4, 4), (8, 8)]
 
-  #   with self.test_session():
+  #   with self.cached_session():
   #     # Here comes a sample Seq2Seq model using GRU cells.
   #     def SampleGRUSeq2Seq(enc_inp, dec_inp, weights, per_example_loss):
   #       """Example sequence-to-sequence model that uses GRU cells."""
@@ -839,7 +839,7 @@
     random.seed(111)
     np.random.seed(111)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # We use sampled softmax so we keep output projection separate.
       w = variable_scope.get_variable("proj_w", [24, classes])
       w_t = array_ops.transpose(w)
diff --git a/tensorflow/contrib/linalg/BUILD b/tensorflow/contrib/linalg/BUILD
deleted file mode 100644
index 78b7970..0000000
--- a/tensorflow/contrib/linalg/BUILD
+++ /dev/null
@@ -1,44 +0,0 @@
-# Description:
-#   Contains classes that provide access to common method of a [batch] matrix,
-#   without the need to instantiate the matrix.
-#   This allows for exploitation of structure, as well as a generic interface
-#   suitable for iterative solvers.
-
-licenses(["notice"])  # Apache 2.0
-
-exports_files(["LICENSE"])
-
-package(default_visibility = ["//tensorflow:__subpackages__"])
-
-load("//tensorflow:tensorflow.bzl", "cuda_py_test")
-
-py_library(
-    name = "linalg_py",
-    srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
-    srcs_version = "PY2AND3",
-    deps = [
-        "//tensorflow/python:array_ops",
-        "//tensorflow/python:check_ops",
-        "//tensorflow/python:framework_for_generated_wrappers",
-        "//tensorflow/python:util",
-        "//tensorflow/python/ops/linalg",
-        "@six_archive//:six",
-    ],
-)
-
-cuda_py_test(
-    name = "linear_operator_addition_test",
-    size = "small",
-    srcs = ["python/kernel_tests/linear_operator_addition_test.py"],
-    additional_deps = [
-        ":linalg_py",
-        "//third_party/py/numpy",
-        "//tensorflow/python:array_ops",
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:framework",
-        "//tensorflow/python:framework_for_generated_wrappers",
-        "//tensorflow/python:framework_test_lib",
-        "//tensorflow/python:math_ops",
-        "//tensorflow/python:platform_test",
-    ],
-)
diff --git a/tensorflow/contrib/linalg/__init__.py b/tensorflow/contrib/linalg/__init__.py
deleted file mode 100644
index cbe4c03..0000000
--- a/tensorflow/contrib/linalg/__init__.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Linear algebra libraries.
-
-See the [Contrib Linalg](https://tensorflow.org/api_guides/python/contrib.linalg)
-guide.
-
-@@LinearOperator
-@@LinearOperatorBlockDiag
-@@LinearOperatorCirculant
-@@LinearOperatorCirculant2D
-@@LinearOperatorCirculant3D
-@@LinearOperatorDiag
-@@LinearOperatorIdentity
-@@LinearOperatorScaledIdentity
-@@LinearOperatorFullMatrix
-@@LinearOperatorKronecker
-@@LinearOperatorLowerTriangular
-@@LinearOperatorLowRankUpdate
-@@LinearOperatorComposition
-@@add_operators
-
-"""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member
-
-from tensorflow.contrib.linalg.python.ops.linear_operator_addition import *
-from tensorflow.python.ops.linalg.linear_operator import *
-from tensorflow.python.ops.linalg.linear_operator_block_diag import *
-from tensorflow.python.ops.linalg.linear_operator_circulant import *
-from tensorflow.python.ops.linalg.linear_operator_composition import *
-from tensorflow.python.ops.linalg.linear_operator_diag import *
-from tensorflow.python.ops.linalg.linear_operator_full_matrix import *
-from tensorflow.python.ops.linalg.linear_operator_identity import *
-from tensorflow.python.ops.linalg.linear_operator_kronecker import *
-from tensorflow.python.ops.linalg.linear_operator_low_rank_update import *
-from tensorflow.python.ops.linalg.linear_operator_lower_triangular import *
-
-# pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member
-
-from tensorflow.python.util.all_util import remove_undocumented
-
-remove_undocumented(__name__)
diff --git a/tensorflow/contrib/linalg/python/__init__.py b/tensorflow/contrib/linalg/python/__init__.py
deleted file mode 100644
index c5ca3a6..0000000
--- a/tensorflow/contrib/linalg/python/__init__.py
+++ /dev/null
@@ -1,19 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""ops module."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_addition_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_addition_test.py
deleted file mode 100644
index 6a72df6..0000000
--- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_addition_test.py
+++ /dev/null
@@ -1,412 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.linalg.python.ops import linear_operator_addition
-from tensorflow.python.framework import random_seed
-from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops.linalg import linalg as linalg_lib
-from tensorflow.python.platform import test
-
-linalg = linalg_lib
-random_seed.set_random_seed(23)
-rng = np.random.RandomState(0)
-
-add_operators = linear_operator_addition.add_operators
-
-
-# pylint: disable=unused-argument
-class _BadAdder(linear_operator_addition._Adder):
-  """Adder that will fail if used."""
-
-  def can_add(self, op1, op2):
-    raise AssertionError("BadAdder.can_add called!")
-
-  def _add(self, op1, op2, operator_name, hints):
-    raise AssertionError("This line should not be reached")
-
-
-# pylint: enable=unused-argument
-
-
-class LinearOperatorAdditionCorrectnessTest(test.TestCase):
-  """Tests correctness of addition with combinations of a few Adders.
-
-  Tests here are done with the _DEFAULT_ADDITION_TIERS, which means
-  add_operators should reduce all operators resulting in one single operator.
-
-  This shows that we are able to correctly combine adders using the tiered
-  system.  All Adders should be tested separately, and there is no need to test
-  every Adder within this class.
-  """
-
-  def test_one_operator_is_returned_unchanged(self):
-    op_a = linalg.LinearOperatorDiag([1., 1.])
-    op_sum = add_operators([op_a])
-    self.assertEqual(1, len(op_sum))
-    self.assertTrue(op_sum[0] is op_a)
-
-  def test_at_least_one_operators_required(self):
-    with self.assertRaisesRegexp(ValueError, "must contain at least one"):
-      add_operators([])
-
-  def test_attempting_to_add_numbers_raises(self):
-    with self.assertRaisesRegexp(TypeError, "contain only LinearOperator"):
-      add_operators([1, 2])
-
-  def test_two_diag_operators(self):
-    op_a = linalg.LinearOperatorDiag(
-        [1., 1.], is_positive_definite=True, name="A")
-    op_b = linalg.LinearOperatorDiag(
-        [2., 2.], is_positive_definite=True, name="B")
-    with self.test_session():
-      op_sum = add_operators([op_a, op_b])
-      self.assertEqual(1, len(op_sum))
-      op = op_sum[0]
-      self.assertTrue(isinstance(op, linalg_lib.LinearOperatorDiag))
-      self.assertAllClose([[3., 0.], [0., 3.]], op.to_dense().eval())
-      # Adding positive definite operators produces positive def.
-      self.assertTrue(op.is_positive_definite)
-      # Real diagonal ==> self-adjoint.
-      self.assertTrue(op.is_self_adjoint)
-      # Positive definite ==> non-singular
-      self.assertTrue(op.is_non_singular)
-      # Enforce particular name for this simple case
-      self.assertEqual("Add/B__A/", op.name)
-
-  def test_three_diag_operators(self):
-    op1 = linalg.LinearOperatorDiag(
-        [1., 1.], is_positive_definite=True, name="op1")
-    op2 = linalg.LinearOperatorDiag(
-        [2., 2.], is_positive_definite=True, name="op2")
-    op3 = linalg.LinearOperatorDiag(
-        [3., 3.], is_positive_definite=True, name="op3")
-    with self.test_session():
-      op_sum = add_operators([op1, op2, op3])
-      self.assertEqual(1, len(op_sum))
-      op = op_sum[0]
-      self.assertTrue(isinstance(op, linalg_lib.LinearOperatorDiag))
-      self.assertAllClose([[6., 0.], [0., 6.]], op.to_dense().eval())
-      # Adding positive definite operators produces positive def.
-      self.assertTrue(op.is_positive_definite)
-      # Real diagonal ==> self-adjoint.
-      self.assertTrue(op.is_self_adjoint)
-      # Positive definite ==> non-singular
-      self.assertTrue(op.is_non_singular)
-
-  def test_diag_tril_diag(self):
-    op1 = linalg.LinearOperatorDiag(
-        [1., 1.], is_non_singular=True, name="diag_a")
-    op2 = linalg.LinearOperatorLowerTriangular(
-        [[2., 0.], [0., 2.]],
-        is_self_adjoint=True,
-        is_non_singular=True,
-        name="tril")
-    op3 = linalg.LinearOperatorDiag(
-        [3., 3.], is_non_singular=True, name="diag_b")
-    with self.test_session():
-      op_sum = add_operators([op1, op2, op3])
-      self.assertEqual(1, len(op_sum))
-      op = op_sum[0]
-      self.assertTrue(isinstance(op, linalg_lib.LinearOperatorLowerTriangular))
-      self.assertAllClose([[6., 0.], [0., 6.]], op.to_dense().eval())
-
-      # The diag operators will be self-adjoint (because real and diagonal).
-      # The TriL operator has the self-adjoint hint set.
-      self.assertTrue(op.is_self_adjoint)
-
-      # Even though op1/2/3 are non-singular, this does not imply op is.
-      # Since no custom hint was provided, we default to None (unknown).
-      self.assertEqual(None, op.is_non_singular)
-
-  def test_matrix_diag_tril_diag_uses_custom_name(self):
-    op0 = linalg.LinearOperatorFullMatrix(
-        [[-1., -1.], [-1., -1.]], name="matrix")
-    op1 = linalg.LinearOperatorDiag([1., 1.], name="diag_a")
-    op2 = linalg.LinearOperatorLowerTriangular(
-        [[2., 0.], [1.5, 2.]], name="tril")
-    op3 = linalg.LinearOperatorDiag([3., 3.], name="diag_b")
-    with self.test_session():
-      op_sum = add_operators([op0, op1, op2, op3], operator_name="my_operator")
-      self.assertEqual(1, len(op_sum))
-      op = op_sum[0]
-      self.assertTrue(isinstance(op, linalg_lib.LinearOperatorFullMatrix))
-      self.assertAllClose([[5., -1.], [0.5, 5.]], op.to_dense().eval())
-      self.assertEqual("my_operator", op.name)
-
-  def test_incompatible_domain_dimensions_raises(self):
-    op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3))
-    op2 = linalg.LinearOperatorDiag(rng.rand(2, 4))
-    with self.assertRaisesRegexp(ValueError, "must.*same domain dimension"):
-      add_operators([op1, op2])
-
-  def test_incompatible_range_dimensions_raises(self):
-    op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3))
-    op2 = linalg.LinearOperatorDiag(rng.rand(3, 3))
-    with self.assertRaisesRegexp(ValueError, "must.*same range dimension"):
-      add_operators([op1, op2])
-
-  def test_non_broadcastable_batch_shape_raises(self):
-    op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3))
-    op2 = linalg.LinearOperatorDiag(rng.rand(4, 3, 3))
-    with self.assertRaisesRegexp(ValueError, "Incompatible shapes"):
-      add_operators([op1, op2])
-
-
-class LinearOperatorOrderOfAdditionTest(test.TestCase):
-  """Test that the order of addition is done as specified by tiers."""
-
-  def test_tier_0_additions_done_in_tier_0(self):
-    diag1 = linalg.LinearOperatorDiag([1.])
-    diag2 = linalg.LinearOperatorDiag([1.])
-    diag3 = linalg.LinearOperatorDiag([1.])
-    addition_tiers = [
-        [linear_operator_addition._AddAndReturnDiag()],
-        [_BadAdder()],
-    ]
-    # Should not raise since all were added in tier 0, and tier 1 (with the
-    # _BadAdder) was never reached.
-    op_sum = add_operators([diag1, diag2, diag3], addition_tiers=addition_tiers)
-    self.assertEqual(1, len(op_sum))
-    self.assertTrue(isinstance(op_sum[0], linalg.LinearOperatorDiag))
-
-  def test_tier_1_additions_done_by_tier_1(self):
-    diag1 = linalg.LinearOperatorDiag([1.])
-    diag2 = linalg.LinearOperatorDiag([1.])
-    tril = linalg.LinearOperatorLowerTriangular([[1.]])
-    addition_tiers = [
-        [linear_operator_addition._AddAndReturnDiag()],
-        [linear_operator_addition._AddAndReturnTriL()],
-        [_BadAdder()],
-    ]
-    # Should not raise since all were added by tier 1, and tier 2 (with the
-    # _BadAdder) was never reached.
-    op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
-    self.assertEqual(1, len(op_sum))
-    self.assertTrue(isinstance(op_sum[0], linalg.LinearOperatorLowerTriangular))
-
-  def test_tier_1_additions_done_by_tier_1_with_order_flipped(self):
-    diag1 = linalg.LinearOperatorDiag([1.])
-    diag2 = linalg.LinearOperatorDiag([1.])
-    tril = linalg.LinearOperatorLowerTriangular([[1.]])
-    addition_tiers = [
-        [linear_operator_addition._AddAndReturnTriL()],
-        [linear_operator_addition._AddAndReturnDiag()],
-        [_BadAdder()],
-    ]
-    # Tier 0 could convert to TriL, and this converted everything to TriL,
-    # including the Diags.
-    # Tier 1 was never used.
-    # Tier 2 was never used (therefore, _BadAdder didn't raise).
-    op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
-    self.assertEqual(1, len(op_sum))
-    self.assertTrue(isinstance(op_sum[0], linalg.LinearOperatorLowerTriangular))
-
-  def test_cannot_add_everything_so_return_more_than_one_operator(self):
-    diag1 = linalg.LinearOperatorDiag([1.])
-    diag2 = linalg.LinearOperatorDiag([2.])
-    tril5 = linalg.LinearOperatorLowerTriangular([[5.]])
-    addition_tiers = [
-        [linear_operator_addition._AddAndReturnDiag()],
-    ]
-    # Tier 0 (the only tier) can only convert to Diag, so it combines the two
-    # diags, but the TriL is unchanged.
-    # Result should contain two operators, one Diag, one TriL.
-    op_sum = add_operators([diag1, diag2, tril5], addition_tiers=addition_tiers)
-    self.assertEqual(2, len(op_sum))
-    found_diag = False
-    found_tril = False
-    with self.test_session():
-      for op in op_sum:
-        if isinstance(op, linalg.LinearOperatorDiag):
-          found_diag = True
-          self.assertAllClose([[3.]], op.to_dense().eval())
-        if isinstance(op, linalg.LinearOperatorLowerTriangular):
-          found_tril = True
-          self.assertAllClose([[5.]], op.to_dense().eval())
-      self.assertTrue(found_diag and found_tril)
-
-  def test_intermediate_tier_is_not_skipped(self):
-    diag1 = linalg.LinearOperatorDiag([1.])
-    diag2 = linalg.LinearOperatorDiag([1.])
-    tril = linalg.LinearOperatorLowerTriangular([[1.]])
-    addition_tiers = [
-        [linear_operator_addition._AddAndReturnDiag()],
-        [_BadAdder()],
-        [linear_operator_addition._AddAndReturnTriL()],
-    ]
-    # tril cannot be added in tier 0, and the intermediate tier 1 with the
-    # BadAdder will catch it and raise.
-    with self.assertRaisesRegexp(AssertionError, "BadAdder.can_add called"):
-      add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
-
-
-class AddAndReturnScaledIdentityTest(test.TestCase):
-
-  def setUp(self):
-    self._adder = linear_operator_addition._AddAndReturnScaledIdentity()
-
-  def test_identity_plus_identity(self):
-    id1 = linalg.LinearOperatorIdentity(num_rows=2)
-    id2 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3])
-    hints = linear_operator_addition._Hints(
-        is_positive_definite=True, is_non_singular=True)
-
-    self.assertTrue(self._adder.can_add(id1, id2))
-    operator = self._adder.add(id1, id2, "my_operator", hints)
-    self.assertTrue(isinstance(operator, linalg.LinearOperatorScaledIdentity))
-
-    with self.test_session():
-      self.assertAllClose(2 *
-                          linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
-                          operator.to_dense().eval())
-    self.assertTrue(operator.is_positive_definite)
-    self.assertTrue(operator.is_non_singular)
-    self.assertEqual("my_operator", operator.name)
-
-  def test_identity_plus_scaled_identity(self):
-    id1 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3])
-    id2 = linalg.LinearOperatorScaledIdentity(num_rows=2, multiplier=2.2)
-    hints = linear_operator_addition._Hints(
-        is_positive_definite=True, is_non_singular=True)
-
-    self.assertTrue(self._adder.can_add(id1, id2))
-    operator = self._adder.add(id1, id2, "my_operator", hints)
-    self.assertTrue(isinstance(operator, linalg.LinearOperatorScaledIdentity))
-
-    with self.test_session():
-      self.assertAllClose(3.2 *
-                          linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
-                          operator.to_dense().eval())
-    self.assertTrue(operator.is_positive_definite)
-    self.assertTrue(operator.is_non_singular)
-    self.assertEqual("my_operator", operator.name)
-
-  def test_scaled_identity_plus_scaled_identity(self):
-    id1 = linalg.LinearOperatorScaledIdentity(
-        num_rows=2, multiplier=[2.2, 2.2, 2.2])
-    id2 = linalg.LinearOperatorScaledIdentity(num_rows=2, multiplier=-1.0)
-    hints = linear_operator_addition._Hints(
-        is_positive_definite=True, is_non_singular=True)
-
-    self.assertTrue(self._adder.can_add(id1, id2))
-    operator = self._adder.add(id1, id2, "my_operator", hints)
-    self.assertTrue(isinstance(operator, linalg.LinearOperatorScaledIdentity))
-
-    with self.test_session():
-      self.assertAllClose(1.2 *
-                          linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
-                          operator.to_dense().eval())
-    self.assertTrue(operator.is_positive_definite)
-    self.assertTrue(operator.is_non_singular)
-    self.assertEqual("my_operator", operator.name)
-
-
-class AddAndReturnDiagTest(test.TestCase):
-
-  def setUp(self):
-    self._adder = linear_operator_addition._AddAndReturnDiag()
-
-  def test_identity_plus_identity_returns_diag(self):
-    id1 = linalg.LinearOperatorIdentity(num_rows=2)
-    id2 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3])
-    hints = linear_operator_addition._Hints(
-        is_positive_definite=True, is_non_singular=True)
-
-    self.assertTrue(self._adder.can_add(id1, id2))
-    operator = self._adder.add(id1, id2, "my_operator", hints)
-    self.assertTrue(isinstance(operator, linalg.LinearOperatorDiag))
-
-    with self.test_session():
-      self.assertAllClose(2 *
-                          linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
-                          operator.to_dense().eval())
-    self.assertTrue(operator.is_positive_definite)
-    self.assertTrue(operator.is_non_singular)
-    self.assertEqual("my_operator", operator.name)
-
-  def test_diag_plus_diag(self):
-    diag1 = rng.rand(2, 3, 4)
-    diag2 = rng.rand(4)
-    op1 = linalg.LinearOperatorDiag(diag1)
-    op2 = linalg.LinearOperatorDiag(diag2)
-    hints = linear_operator_addition._Hints(
-        is_positive_definite=True, is_non_singular=True)
-
-    self.assertTrue(self._adder.can_add(op1, op2))
-    operator = self._adder.add(op1, op2, "my_operator", hints)
-    self.assertTrue(isinstance(operator, linalg.LinearOperatorDiag))
-
-    with self.test_session():
-      self.assertAllClose(
-          linalg.LinearOperatorDiag(diag1 + diag2).to_dense().eval(),
-          operator.to_dense().eval())
-    self.assertTrue(operator.is_positive_definite)
-    self.assertTrue(operator.is_non_singular)
-    self.assertEqual("my_operator", operator.name)
-
-
-class AddAndReturnTriLTest(test.TestCase):
-
-  def setUp(self):
-    self._adder = linear_operator_addition._AddAndReturnTriL()
-
-  def test_diag_plus_tril(self):
-    diag = linalg.LinearOperatorDiag([1., 2.])
-    tril = linalg.LinearOperatorLowerTriangular([[10., 0.], [30., 0.]])
-    hints = linear_operator_addition._Hints(
-        is_positive_definite=True, is_non_singular=True)
-
-    self.assertTrue(self._adder.can_add(diag, diag))
-    self.assertTrue(self._adder.can_add(diag, tril))
-    operator = self._adder.add(diag, tril, "my_operator", hints)
-    self.assertTrue(isinstance(operator, linalg.LinearOperatorLowerTriangular))
-
-    with self.test_session():
-      self.assertAllClose([[11., 0.], [30., 2.]], operator.to_dense().eval())
-    self.assertTrue(operator.is_positive_definite)
-    self.assertTrue(operator.is_non_singular)
-    self.assertEqual("my_operator", operator.name)
-
-
-class AddAndReturnMatrixTest(test.TestCase):
-
-  def setUp(self):
-    self._adder = linear_operator_addition._AddAndReturnMatrix()
-
-  def test_diag_plus_diag(self):
-    diag1 = linalg.LinearOperatorDiag([1., 2.])
-    diag2 = linalg.LinearOperatorDiag([-1., 3.])
-    hints = linear_operator_addition._Hints(
-        is_positive_definite=False, is_non_singular=False)
-
-    self.assertTrue(self._adder.can_add(diag1, diag2))
-    operator = self._adder.add(diag1, diag2, "my_operator", hints)
-    self.assertTrue(isinstance(operator, linalg.LinearOperatorFullMatrix))
-
-    with self.test_session():
-      self.assertAllClose([[0., 0.], [0., 5.]], operator.to_dense().eval())
-    self.assertFalse(operator.is_positive_definite)
-    self.assertFalse(operator.is_non_singular)
-    self.assertEqual("my_operator", operator.name)
-
-
-if __name__ == "__main__":
-  test.main()
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_addition.py b/tensorflow/contrib/linalg/python/ops/linear_operator_addition.py
deleted file mode 100644
index 86130a2..0000000
--- a/tensorflow/contrib/linalg/python/ops/linear_operator_addition.py
+++ /dev/null
@@ -1,432 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Add one or more `LinearOperators` efficiently."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import abc
-
-import six
-
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import check_ops
-from tensorflow.python.ops.linalg import linear_operator
-from tensorflow.python.ops.linalg import linear_operator_diag
-from tensorflow.python.ops.linalg import linear_operator_full_matrix
-from tensorflow.python.ops.linalg import linear_operator_identity
-from tensorflow.python.ops.linalg import linear_operator_lower_triangular
-
-__all__ = []
-
-
-def add_operators(operators,
-                  operator_name=None,
-                  addition_tiers=None,
-                  name=None):
-  """Efficiently add one or more linear operators.
-
-  Given operators `[A1, A2,...]`, this `Op` returns a possibly shorter list of
-  operators `[B1, B2,...]` such that
-
-  ```sum_k Ak.matmul(x) = sum_k Bk.matmul(x).```
-
-  The operators `Bk` result by adding some of the `Ak`, as allowed by
-  `addition_tiers`.
-
-  Example of efficient adding of diagonal operators.
-
-  ```python
-  A1 = LinearOperatorDiag(diag=[1., 1.], name="A1")
-  A2 = LinearOperatorDiag(diag=[2., 2.], name="A2")
-
-  # Use two tiers, the first contains an Adder that returns Diag.  Since both
-  # A1 and A2 are Diag, they can use this Adder.  The second tier will not be
-  # used.
-  addition_tiers = [
-      [_AddAndReturnDiag()],
-      [_AddAndReturnMatrix()]]
-  B_list = add_operators([A1, A2], addition_tiers=addition_tiers)
-
-  len(B_list)
-  ==> 1
-
-  B_list[0].__class__.__name__
-  ==> 'LinearOperatorDiag'
-
-  B_list[0].to_dense()
-  ==> [[3., 0.],
-       [0., 3.]]
-
-  B_list[0].name
-  ==> 'Add/A1__A2/'
-  ```
-
-  Args:
-    operators:  Iterable of `LinearOperator` objects with same `dtype`, domain
-      and range dimensions, and broadcastable batch shapes.
-    operator_name:  String name for returned `LinearOperator`.  Defaults to
-      concatenation of "Add/A__B/" that indicates the order of addition steps.
-    addition_tiers:  List tiers, like `[tier_0, tier_1, ...]`, where `tier_i`
-      is a list of `Adder` objects.  This function attempts to do all additions
-      in tier `i` before trying tier `i + 1`.
-    name:  A name for this `Op`.  Defaults to `add_operators`.
-
-  Returns:
-    Subclass of `LinearOperator`.  Class and order of addition may change as new
-      (and better) addition strategies emerge.
-
-  Raises:
-    ValueError:  If `operators` argument is empty.
-    ValueError:  If shapes are incompatible.
-  """
-  # Default setting
-  if addition_tiers is None:
-    addition_tiers = _DEFAULT_ADDITION_TIERS
-
-  # Argument checking.
-  check_ops.assert_proper_iterable(operators)
-  operators = list(reversed(operators))
-  if len(operators) < 1:
-    raise ValueError(
-        "Argument 'operators' must contain at least one operator.  "
-        "Found: %s" % operators)
-  if not all(
-      isinstance(op, linear_operator.LinearOperator) for op in operators):
-    raise TypeError(
-        "Argument 'operators' must contain only LinearOperator instances.  "
-        "Found: %s" % operators)
-  _static_check_for_same_dimensions(operators)
-  _static_check_for_broadcastable_batch_shape(operators)
-
-  graph_parents = []
-  for operator in operators:
-    graph_parents.extend(operator.graph_parents)
-
-  with ops.name_scope(name or "add_operators", values=graph_parents):
-
-    # Additions done in one of the tiers.  Try tier 0, 1,...
-    ops_to_try_at_next_tier = list(operators)
-    for tier in addition_tiers:
-      ops_to_try_at_this_tier = ops_to_try_at_next_tier
-      ops_to_try_at_next_tier = []
-      while ops_to_try_at_this_tier:
-        op1 = ops_to_try_at_this_tier.pop()
-        op2, adder = _pop_a_match_at_tier(op1, ops_to_try_at_this_tier, tier)
-        if op2 is not None:
-          # Will try to add the result of this again at this same tier.
-          new_operator = adder.add(op1, op2, operator_name)
-          ops_to_try_at_this_tier.append(new_operator)
-        else:
-          ops_to_try_at_next_tier.append(op1)
-
-    return ops_to_try_at_next_tier
-
-
-def _pop_a_match_at_tier(op1, operator_list, tier):
-  # Search from the back of list to the front in order to create nice default
-  # order of operations.
-  for i in range(1, len(operator_list) + 1):
-    op2 = operator_list[-i]
-    for adder in tier:
-      if adder.can_add(op1, op2):
-        return operator_list.pop(-i), adder
-  return None, None
-
-
-def _infer_hints_allowing_override(op1, op2, hints):
-  """Infer hints from op1 and op2.  hints argument is an override.
-
-  Args:
-    op1:  LinearOperator
-    op2:  LinearOperator
-    hints:  _Hints object holding "is_X" boolean hints to use for returned
-      operator.
-      If some hint is None, try to set using op1 and op2.  If the
-      hint is provided, ignore op1 and op2 hints.  This allows an override
-      of previous hints, but does not allow forbidden hints (e.g. you still
-      cannot say a real diagonal operator is not self-adjoint).
-
-  Returns:
-    _Hints object.
-  """
-  hints = hints or _Hints()
-  # If A, B are self-adjoint, then so is A + B.
-  if hints.is_self_adjoint is None:
-    is_self_adjoint = op1.is_self_adjoint and op2.is_self_adjoint
-  else:
-    is_self_adjoint = hints.is_self_adjoint
-
-  # If A, B are positive definite, then so is A + B.
-  if hints.is_positive_definite is None:
-    is_positive_definite = op1.is_positive_definite and op2.is_positive_definite
-  else:
-    is_positive_definite = hints.is_positive_definite
-
-  # A positive definite operator is always non-singular.
-  if is_positive_definite and hints.is_positive_definite is None:
-    is_non_singular = True
-  else:
-    is_non_singular = hints.is_non_singular
-
-  return _Hints(
-      is_non_singular=is_non_singular,
-      is_self_adjoint=is_self_adjoint,
-      is_positive_definite=is_positive_definite)
-
-
-def _static_check_for_same_dimensions(operators):
-  """ValueError if operators determined to have different dimensions."""
-  if len(operators) < 2:
-    return
-
-  domain_dimensions = [(op.name, op.domain_dimension.value) for op in operators
-                       if op.domain_dimension.value is not None]
-  if len(set(value for name, value in domain_dimensions)) > 1:
-    raise ValueError("Operators must have the same domain dimension. Found: %s"
-                     % domain_dimensions)
-
-  range_dimensions = [(op.name, op.range_dimension.value) for op in operators
-                      if op.range_dimension.value is not None]
-  if len(set(value for name, value in range_dimensions)) > 1:
-    raise ValueError("Operators must have the same range dimension. Found: %s" %
-                     range_dimensions)
-
-
-def _static_check_for_broadcastable_batch_shape(operators):
-  """ValueError if operators determined to have non-broadcastable shapes."""
-  if len(operators) < 2:
-    return
-
-  # This will fail if they cannot be broadcast together.
-  batch_shape = operators[0].batch_shape
-  for op in operators[1:]:
-    batch_shape = array_ops.broadcast_static_shape(batch_shape, op.batch_shape)
-
-
-class _Hints(object):
-  """Holds 'is_X' flags that every LinearOperator is initialized with."""
-
-  def __init__(self,
-               is_non_singular=None,
-               is_positive_definite=None,
-               is_self_adjoint=None):
-    self.is_non_singular = is_non_singular
-    self.is_positive_definite = is_positive_definite
-    self.is_self_adjoint = is_self_adjoint
-
-
-################################################################################
-# Classes to add two linear operators.
-################################################################################
-
-
-@six.add_metaclass(abc.ABCMeta)
-class _Adder(object):
-  """Abstract base class to add two operators.
-
-  Each `Adder` acts independently, adding everything it can, paying no attention
-  as to whether another `Adder` could have done the addition more efficiently.
-  """
-
-  @property
-  def name(self):
-    return self.__class__.__name__
-
-  @abc.abstractmethod
-  def can_add(self, op1, op2):
-    """Returns `True` if this `Adder` can add `op1` and `op2`.  Else `False`."""
-    pass
-
-  @abc.abstractmethod
-  def _add(self, op1, op2, operator_name, hints):
-    # Derived classes can assume op1 and op2 have been validated, e.g. they have
-    # the same dtype, and their domain/range dimensions match.
-    pass
-
-  def add(self, op1, op2, operator_name, hints=None):
-    """Return new `LinearOperator` acting like `op1 + op2`.
-
-    Args:
-      op1:  `LinearOperator`
-      op2:  `LinearOperator`, with `shape` and `dtype` such that adding to
-        `op1` is allowed.
-      operator_name:  `String` name to give to returned `LinearOperator`
-      hints:  `_Hints` object.  Returned `LinearOperator` will be created with
-        these hints.
-
-    Returns:
-      `LinearOperator`
-    """
-    updated_hints = _infer_hints_allowing_override(op1, op2, hints)
-
-    if operator_name is None:
-      operator_name = "Add/" + op1.name + "__" + op2.name + "/"
-
-    values = op1.graph_parents + op2.graph_parents
-    scope_name = self.name
-    if scope_name.startswith("_"):
-      scope_name = scope_name[1:]
-    with ops.name_scope(scope_name, values=values):
-      return self._add(op1, op2, operator_name, updated_hints)
-
-
-class _AddAndReturnScaledIdentity(_Adder):
-  """Handles additions resulting in an Identity family member.
-
-  The Identity (`LinearOperatorScaledIdentity`, `LinearOperatorIdentity`) family
-  is closed under addition.  This `Adder` respects that, and returns an Identity
-  """
-
-  def can_add(self, op1, op2):
-    types = {_type(op1), _type(op2)}
-    return not types.difference(_IDENTITY_FAMILY)
-
-  def _add(self, op1, op2, operator_name, hints):
-    # Will build a LinearOperatorScaledIdentity.
-
-    if _type(op1) == _SCALED_IDENTITY:
-      multiplier_1 = op1.multiplier
-    else:
-      multiplier_1 = array_ops.ones(op1.batch_shape_tensor(), dtype=op1.dtype)
-
-    if _type(op2) == _SCALED_IDENTITY:
-      multiplier_2 = op2.multiplier
-    else:
-      multiplier_2 = array_ops.ones(op2.batch_shape_tensor(), dtype=op2.dtype)
-
-    return linear_operator_identity.LinearOperatorScaledIdentity(
-        num_rows=op1.range_dimension_tensor(),
-        multiplier=multiplier_1 + multiplier_2,
-        is_non_singular=hints.is_non_singular,
-        is_self_adjoint=hints.is_self_adjoint,
-        is_positive_definite=hints.is_positive_definite,
-        name=operator_name)
-
-
-class _AddAndReturnDiag(_Adder):
-  """Handles additions resulting in a Diag operator."""
-
-  def can_add(self, op1, op2):
-    types = {_type(op1), _type(op2)}
-    return not types.difference(_DIAG_LIKE)
-
-  def _add(self, op1, op2, operator_name, hints):
-    return linear_operator_diag.LinearOperatorDiag(
-        diag=op1.diag_part() + op2.diag_part(),
-        is_non_singular=hints.is_non_singular,
-        is_self_adjoint=hints.is_self_adjoint,
-        is_positive_definite=hints.is_positive_definite,
-        name=operator_name)
-
-
-class _AddAndReturnTriL(_Adder):
-  """Handles additions resulting in a TriL operator."""
-
-  def can_add(self, op1, op2):
-    types = {_type(op1), _type(op2)}
-    return not types.difference(_DIAG_LIKE.union({_TRIL}))
-
-  def _add(self, op1, op2, operator_name, hints):
-    if _type(op1) in _EFFICIENT_ADD_TO_TENSOR:
-      op_add_to_tensor, op_other = op1, op2
-    else:
-      op_add_to_tensor, op_other = op2, op1
-
-    return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
-        tril=op_add_to_tensor.add_to_tensor(op_other.to_dense()),
-        is_non_singular=hints.is_non_singular,
-        is_self_adjoint=hints.is_self_adjoint,
-        is_positive_definite=hints.is_positive_definite,
-        name=operator_name)
-
-
-class _AddAndReturnMatrix(_Adder):
-  """"Handles additions resulting in a `LinearOperatorFullMatrix`."""
-
-  def can_add(self, op1, op2):  # pylint: disable=unused-argument
-    return isinstance(op1, linear_operator.LinearOperator) and isinstance(
-        op2, linear_operator.LinearOperator)
-
-  def _add(self, op1, op2, operator_name, hints):
-    if _type(op1) in _EFFICIENT_ADD_TO_TENSOR:
-      op_add_to_tensor, op_other = op1, op2
-    else:
-      op_add_to_tensor, op_other = op2, op1
-    return linear_operator_full_matrix.LinearOperatorFullMatrix(
-        matrix=op_add_to_tensor.add_to_tensor(op_other.to_dense()),
-        is_non_singular=hints.is_non_singular,
-        is_self_adjoint=hints.is_self_adjoint,
-        is_positive_definite=hints.is_positive_definite,
-        name=operator_name)
-
-
-################################################################################
-# Constants designating types of LinearOperators
-################################################################################
-
-# Type name constants for LinearOperator classes.
-_IDENTITY = "identity"
-_SCALED_IDENTITY = "scaled_identity"
-_DIAG = "diag"
-_TRIL = "tril"
-_MATRIX = "matrix"
-
-# Groups of operators.
-_DIAG_LIKE = {_DIAG, _IDENTITY, _SCALED_IDENTITY}
-_IDENTITY_FAMILY = {_IDENTITY, _SCALED_IDENTITY}
-# operators with an efficient .add_to_tensor() method.
-_EFFICIENT_ADD_TO_TENSOR = _DIAG_LIKE
-
-
-def _type(operator):
-  """Returns the type name constant (e.g. _TRIL) for operator."""
-  if isinstance(operator, linear_operator_diag.LinearOperatorDiag):
-    return _DIAG
-  if isinstance(operator,
-                linear_operator_lower_triangular.LinearOperatorLowerTriangular):
-    return _TRIL
-  if isinstance(operator, linear_operator_full_matrix.LinearOperatorFullMatrix):
-    return _MATRIX
-  if isinstance(operator, linear_operator_identity.LinearOperatorIdentity):
-    return _IDENTITY
-  if isinstance(operator,
-                linear_operator_identity.LinearOperatorScaledIdentity):
-    return _SCALED_IDENTITY
-  raise TypeError("Operator type unknown: %s" % operator)
-
-
-################################################################################
-# Addition tiers:
-# We attempt to use Adders in tier K before K+1.
-#
-# Organize tiers to
-#   (i) reduce O(..) complexity of forming final operator, and
-#   (ii) produce the "most efficient" final operator.
-# Dev notes:
-#  * Results of addition at tier K will be added at tier K or higher.
-#  * Tiers may change, and we warn the user that it may change.
-################################################################################
-
-# Note that the final tier, _AddAndReturnMatrix, will convert everything to a
-# dense matrix.  So it is sometimes very inefficient.
-_DEFAULT_ADDITION_TIERS = [
-    [_AddAndReturnScaledIdentity()],
-    [_AddAndReturnDiag()],
-    [_AddAndReturnTriL()],
-    [_AddAndReturnMatrix()],
-]
diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
index 1d2db1c..9ecf023 100644
--- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
@@ -134,7 +134,7 @@
   return examples_dict, variables_dict
 
 
-def make_variable_dict(max_age, max_gender, partitioned=False):
+def make_variable_dict(max_age, max_gender, num_shards=None, partitioned=False):
   # TODO(sibyl-toe9oF2e):  Figure out how to derive max_age & max_gender from
   # examples_dict.
   partitioner = None
@@ -142,14 +142,15 @@
     partitioner = partitioned_variables.fixed_size_partitioner(num_shards=2,
                                                                axis=0)
   with variable_scope.variable_scope(
-      name_or_scope='variables',
+      name_or_scope=('variables/shard_{}'.format(num_shards)
+                     if num_shards else 'variables'),
       partitioner=partitioner):
-    age_weights = variables_lib.Variable(
-        array_ops.zeros(
-            [max_age + 1], dtype=dtypes.float32))
-    gender_weights = variables_lib.Variable(
-        array_ops.zeros(
-            [max_gender + 1], dtype=dtypes.float32))
+    age_weights = variable_scope.get_variable(
+        name='age',
+        initializer=array_ops.zeros([max_age + 1], dtype=dtypes.float32))
+    gender_weights = variable_scope.get_variable(
+        name='gender',
+        initializer=array_ops.zeros([max_gender + 1], dtype=dtypes.float32))
   return dict(
       sparse_features_weights=[age_weights, gender_weights],
       dense_features_weights=[])
@@ -242,7 +243,7 @@
     for num_shards in _SHARD_NUMBERS:
       with self._single_threaded_test_session():
         examples = make_example_dict(example_protos, example_weights)
-        variables = make_variable_dict(1, 1)
+        variables = make_variable_dict(1, 1, num_shards)
         options = dict(
             symmetric_l2_regularization=1,
             symmetric_l1_regularization=0,
@@ -290,7 +291,7 @@
     for num_shards in _SHARD_NUMBERS:
       with self._single_threaded_test_session():
         examples = make_example_dict(example_protos, example_weights)
-        variables = make_variable_dict(1, 1, partitioned=True)
+        variables = make_variable_dict(1, 1, num_shards, partitioned=True)
         options = dict(
             symmetric_l2_regularization=1,
             symmetric_l1_regularization=0,
@@ -322,6 +323,68 @@
         self.assertAllClose(
             0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2)
 
+  def testSomePartitionedPrimals(self):
+    # Setup test data
+    example_protos = [
+        make_example_proto({
+            'age': [0],
+            'gender': [0]
+        }, 0),
+        make_example_proto({
+            'age': [0],
+            'gender': [1]
+        }, 1),
+    ]
+    example_weights = [1.0, 1.0]
+    for num_shards in _SHARD_NUMBERS:
+      with self._single_threaded_test_session():
+        examples = make_example_dict(example_protos, example_weights)
+        # Explicitly make age a [1]-shaped Variable (which cannot be
+        # partitioned), while making gender a PartitionedVariable.
+        age_weights = variables_lib.Variable(
+            array_ops.zeros([1], dtype=dtypes.float32))
+        with variable_scope.variable_scope(
+            name_or_scope=('variables/shard_{}'.format(num_shards)
+                           if num_shards else 'variables'),
+            partitioner=partitioned_variables.fixed_size_partitioner(
+                num_shards=2, axis=0)):
+          gender_weights = variable_scope.get_variable(
+              name='gender',
+              initializer=array_ops.zeros([2], dtype=dtypes.float32))
+        variables = dict(
+            sparse_features_weights=[age_weights, gender_weights],
+            dense_features_weights=[])
+        options = dict(
+            symmetric_l2_regularization=1,
+            symmetric_l1_regularization=0,
+            num_table_shards=num_shards,
+            loss_type='logistic_loss')
+
+        lr = SdcaModel(examples, variables, options)
+        variables_lib.global_variables_initializer().run()
+        unregularized_loss = lr.unregularized_loss(examples)
+        loss = lr.regularized_loss(examples)
+        predictions = lr.predictions(examples)
+        self.assertAllClose(0.693147, unregularized_loss.eval())
+        self.assertAllClose(0.693147, loss.eval())
+        train_op = lr.minimize()
+        for _ in range(_MAX_ITERATIONS):
+          train_op.run()
+        lr.update_weights(train_op).run()
+        # The high tolerance in unregularized_loss comparisons is due to the
+        # fact that it's possible to trade off unregularized_loss vs.
+        # regularization and still have a sum that is quite close to the
+        # optimal regularized_loss value.  SDCA's duality gap only ensures that
+        # the regularized_loss is within 0.01 of optimal.
+        # 0.525457 is the optimal regularized_loss.
+        # 0.593014 is the unregularized_loss at that optimum.
+        self.assertAllClose(0.512591, unregularized_loss.eval(), atol=0.05)
+        self.assertAllClose(0.593014, loss.eval(), atol=0.01)
+        predicted_labels = get_binary_predictions_for_logistic(predictions)
+        self.assertAllEqual([0, 1], predicted_labels.eval())
+        self.assertAllClose(
+            0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2)
+
   def testSparseRandom(self):
     dim = 20
     num_examples = 1000
@@ -463,7 +526,7 @@
     for num_shards in _SHARD_NUMBERS:
       with self._single_threaded_test_session():
         examples = make_example_dict(example_protos, example_weights)
-        variables = make_variable_dict(1, 1)
+        variables = make_variable_dict(1, 1, num_shards)
         options = dict(
             symmetric_l2_regularization=0,
             symmetric_l1_regularization=0,
@@ -521,7 +584,7 @@
       with self._single_threaded_test_session():
         # Only use examples 0 and 2
         examples = make_example_dict(example_protos, example_weights)
-        variables = make_variable_dict(1, 1)
+        variables = make_variable_dict(1, 1, num_shards)
         options = dict(
             symmetric_l2_regularization=1,
             symmetric_l1_regularization=0,
@@ -561,7 +624,7 @@
     for num_shards in _SHARD_NUMBERS:
       with self._single_threaded_test_session():
         examples = make_example_dict(example_protos, example_weights)
-        variables = make_variable_dict(1, 1)
+        variables = make_variable_dict(1, 1, num_shards)
         options = dict(
             symmetric_l2_regularization=1,
             symmetric_l1_regularization=0,
@@ -598,7 +661,7 @@
     for num_shards in _SHARD_NUMBERS:
       with self._single_threaded_test_session():
         examples = make_example_dict(example_protos, example_weights)
-        variables = make_variable_dict(3, 1)
+        variables = make_variable_dict(3, 1, num_shards)
         options = dict(
             symmetric_l2_regularization=1,
             symmetric_l1_regularization=0,
@@ -639,7 +702,7 @@
     for num_shards in _SHARD_NUMBERS:
       with self._single_threaded_test_session():
         examples = make_example_dict(example_protos, example_weights)
-        variables = make_variable_dict(1, 1)
+        variables = make_variable_dict(1, 1, num_shards)
         options = dict(
             symmetric_l2_regularization=1,
             symmetric_l1_regularization=0,
@@ -679,7 +742,7 @@
     for num_shards in _SHARD_NUMBERS:
       with self._single_threaded_test_session():
         examples = make_example_dict(example_protos, example_weights)
-        variables = make_variable_dict(1, 1)
+        variables = make_variable_dict(1, 1, num_shards)
         options = dict(
             symmetric_l2_regularization=1,
             symmetric_l1_regularization=0,
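
For readers unfamiliar with partitioned variables, the `testSomePartitionedPrimals` setup above mixes a plain `Variable` with one created under a `fixed_size_partitioner`. The sketch below is a self-contained illustration of just that setup, written against the same TF 1.x calls the test uses; the scope name is illustrative and not taken from the patch:

```python
import tensorflow as tf

# A [1]-shaped variable cannot be split, so it stays a plain Variable.
age_weights = tf.Variable(tf.zeros([1], dtype=tf.float32))

# Under a fixed_size_partitioner, get_variable() returns a PartitionedVariable
# whose [2]-shaped value is split into 2 shards along axis 0.
with tf.variable_scope(
    "variables",
    partitioner=tf.fixed_size_partitioner(num_shards=2, axis=0)):
  gender_weights = tf.get_variable(
      "gender", initializer=tf.zeros([2], dtype=tf.float32))
```

This is the mixed case the new test exercises: one weight that can never be partitioned next to one that is.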
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
index 14f59a3..b98adf8 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
@@ -400,14 +400,16 @@
 
       sparse_weights = []
       sparse_indices = []
-      # If we have partitioned variables, keep a few lists of Tensors around
-      # that we need for the assign_add after the op call to
-      # gen_sdca_ops.sdca_optimizer().
-      num_partitions_by_var = []
-      p_assignments_by_var = []
-      gather_ids_by_var = []
-      for w, i in zip(self._slots['unshrinked_sparse_features_weights'],
-                      sparse_feature_indices):
+      # If we have partitioned variables, keep a few dictionaries of Tensors
+      # around that we need for the assign_add after the op call to
+      # gen_sdca_ops.sdca_optimizer().  These are keyed because we may have a
+      # mix of partitioned and un-partitioned variables.
+      num_partitions_by_var = {}
+      p_assignments_by_var = {}
+      gather_ids_by_var = {}
+      for v_num, (w, i) in enumerate(
+          zip(self._slots['unshrinked_sparse_features_weights'],
+              sparse_feature_indices)):
         # Append the sparse_indices (in full-variable space).
         sparse_idx = math_ops.cast(
             array_ops.unique(math_ops.cast(i, dtypes.int32))[0],
@@ -456,10 +458,10 @@
           gather_ids = data_flow_ops.dynamic_partition(new_ids,
                                                        p_assignments,
                                                        num_partitions)
-          # Append these to the lists for use in the later update.
-          num_partitions_by_var.append(num_partitions)
-          p_assignments_by_var.append(p_assignments)
-          gather_ids_by_var.append(gather_ids)
+          # Add these into the dictionaries for use in the later update.
+          num_partitions_by_var[v_num] = num_partitions
+          p_assignments_by_var[v_num] = p_assignments
+          gather_ids_by_var[v_num] = gather_ids
 
           # Gather the weights from each partition.
           partition_gathered_weights = []
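
The comment in the hunk above explains why the bookkeeping moved from parallel lists to dictionaries keyed by the variable's position: with a mix of partitioned and un-partitioned variables, only some positions have partition state to record. A toy, self-contained sketch of that pattern (plain Python, hypothetical names, not the SdcaModel code):

```python
variables = [
    {"name": "age", "num_partitions": None},   # un-partitioned
    {"name": "gender", "num_partitions": 2},   # partitioned into 2 shards
]

# Keyed by position, so un-partitioned variables simply have no entry.
num_partitions_by_var = {}
for v_num, var in enumerate(variables):
  if var["num_partitions"] is not None:
    num_partitions_by_var[v_num] = var["num_partitions"]

# Later update step: a missing key means the variable needs no per-partition
# handling and can be updated directly.
for v_num, var in enumerate(variables):
  if v_num in num_partitions_by_var:
    print(var["name"], "-> per-partition update over",
          num_partitions_by_var[v_num], "shards")
  else:
    print(var["name"], "-> direct update")
```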
diff --git a/tensorflow/contrib/lite/README.md b/tensorflow/contrib/lite/README.md
index a676b70..a4b3d83 100644
--- a/tensorflow/contrib/lite/README.md
+++ b/tensorflow/contrib/lite/README.md
@@ -4,5 +4,5 @@
 devices. It enables low-latency inference of on-device machine learning models
 with a small binary size and fast performance supporting hardware acceleration.
 
-See the documentation: https://www.tensorflow.org/mobile/tflite/
-Documentation edits can be made here: [tensorflow/docs_src/mobile/tflite](../../docs_src/mobile/tflite)
+See the documentation: https://www.tensorflow.org/lite/
+Documentation edits can be made here: [tensorflow/contrib/lite/g3doc](./g3doc/)
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index 5c705ea..fc4d9b4 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -294,6 +294,7 @@
         #"transpose_conv",   # disabled due to b/111213074
         "unpack",
         "where",
+        "zeros_like",
     ]
 
 def generated_test_conversion_modes():
@@ -337,11 +338,7 @@
         flags = "--ignore_toco_errors --run_with_extended"
         kwargs["tags"].append("skip_already_failing")
         kwargs["tags"].append("no_oss")
-
-        # TODO(b/115504899): Re-enable asan, msan and tsan tests.
-        kwargs["tags"].append("noasan")
-        kwargs["tags"].append("nomsan")
-        kwargs["tags"].append("notsan")
+        kwargs["tags"].append("notap")
 
     gen_zipped_test_file(
         name = "zip_%s" % test_name,
diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h
index 5e97b77..7809d11 100644
--- a/tensorflow/contrib/lite/builtin_ops.h
+++ b/tensorflow/contrib/lite/builtin_ops.h
@@ -118,6 +118,8 @@
   kTfLiteBuiltinFloorDiv = 90,
   kTfLiteBuiltinReduceAny = 91,
   kTfLiteBuiltinSquare = 92,
+  kTfLiteBuiltinZerosLike = 93,
+  kTfLiteBuiltinFill = 94,
 } TfLiteBuiltinOperator;
 
 #ifdef __cplusplus
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
index f4d2839..03af538 100644
--- a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
@@ -618,6 +618,8 @@
     case BuiltinOperator_LOGICAL_NOT:
     case BuiltinOperator_FLOOR_DIV:
     case BuiltinOperator_SQUARE:
+    case BuiltinOperator_ZEROS_LIKE:
+    case BuiltinOperator_FILL:
       break;
   }
   return kTfLiteOk;
diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
index c6587b3..d85e576 100644
--- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
@@ -518,7 +518,7 @@
         }
         break;
       case kTfLiteBuiltinReshape:
-        if (version == 1) {
+        if (version == 1 && node->inputs->size == 2) {
           return [](const NNAPIOpMappingArgs& mapping_args)
                      -> ANeuralNetworksOperationType {
             return ANEURALNETWORKS_RESHAPE;
diff --git a/tensorflow/contrib/lite/examples/android/app/README.md b/tensorflow/contrib/lite/examples/android/app/README.md
index cbdeeac..dc31171 100644
--- a/tensorflow/contrib/lite/examples/android/app/README.md
+++ b/tensorflow/contrib/lite/examples/android/app/README.md
@@ -2,7 +2,7 @@
 
 ## Building from Source with Bazel
 
-1. Install [Bazel](https://docs.bazel.build/versions/master/install.html), the Android NDK and SDK. The recommended versions are specified on this [webpage](https://www.tensorflow.org/mobile/tflite/demo_android#build_tensorflow_lite_and_the_demo_app_from_source).
+1. Install [Bazel](https://docs.bazel.build/versions/master/install.html), the Android NDK and SDK. The recommended versions are specified on this [webpage](https://www.tensorflow.org/lite/demo_android).
 
 2. Build this demo app with Bazel. The demo needs C++11. We configure the fat_apk_cpu flag to package support for 4 hardware variants. You may replace it with --config=android_arm64 on a 64-bit device and --config=android_arm for 32-bit device:
 
diff --git a/tensorflow/contrib/lite/experimental/c/BUILD b/tensorflow/contrib/lite/experimental/c/BUILD
index ea4a543..52e7161 100644
--- a/tensorflow/contrib/lite/experimental/c/BUILD
+++ b/tensorflow/contrib/lite/experimental/c/BUILD
@@ -1,5 +1,12 @@
 package(default_visibility = ["//visibility:private"])
 
+package_group(
+    name = "experimental",
+    packages = [
+        "//tensorflow/contrib/lite/experimental/...",
+    ],
+)
+
 licenses(["notice"])  # Apache 2.0
 
 load(
@@ -51,6 +58,9 @@
     srcs = ["c_api.cc"],
     hdrs = ["c_api.h"],
     copts = tflite_copts(),
+    visibility = [
+        ":experimental",
+    ],
     deps = [
         ":c_api_internal",
         "//tensorflow/contrib/lite:context",
@@ -68,6 +78,7 @@
     deps = [
         ":c_api",
         ":c_api_internal",
+        "//tensorflow/contrib/lite:kernel_api",
     ],
 )
 
@@ -93,6 +104,7 @@
     deps = [
         ":c_api",
         ":c_api_experimental",
+        "//tensorflow/contrib/lite:kernel_api",
         "//tensorflow/contrib/lite/testing:util",
         "@com_google_googletest//:gtest",
     ],
diff --git a/tensorflow/contrib/lite/experimental/c/c_api.cc b/tensorflow/contrib/lite/experimental/c/c_api.cc
index c589cf7..9c29f9d 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api.cc
+++ b/tensorflow/contrib/lite/experimental/c/c_api.cc
@@ -17,6 +17,7 @@
 #include <memory>
 
 #include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/error_reporter.h"
 #include "tensorflow/contrib/lite/experimental/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/interpreter.h"
 #include "tensorflow/contrib/lite/kernels/register.h"
@@ -26,6 +27,26 @@
 extern "C" {
 #endif  // __cplusplus
 
+namespace {
+class CallbackErrorReporter : public tflite::ErrorReporter {
+ public:
+  using ErrorCallback = void (*)(void* user_data, const char* format,
+                                 va_list args);
+
+  CallbackErrorReporter(ErrorCallback callback, void* user_data)
+      : callback_(callback), user_data_(user_data) {}
+
+  int Report(const char* format, va_list args) override {
+    callback_(user_data_, format, args);
+    return 0;
+  }
+
+ private:
+  ErrorCallback callback_;
+  void* user_data_;
+};
+}  // namespace
+
 // LINT.IfChange
 
 TFL_Model* TFL_NewModel(const void* model_data, size_t model_size) {
@@ -56,14 +77,38 @@
   options->num_threads = num_threads;
 }
 
+TFL_CAPI_EXPORT extern void TFL_InterpreterOptionsSetErrorReporter(
+    TFL_InterpreterOptions* options,
+    void (*reporter)(void* user_data, const char* format, va_list args),
+    void* user_data) {
+  options->error_reporter = reporter;
+  options->error_reporter_user_data = user_data;
+}
+
 TFL_Interpreter* TFL_NewInterpreter(
     const TFL_Model* model, const TFL_InterpreterOptions* optional_options) {
   if (!model || !model->impl) {
     return nullptr;
   }
 
+  std::unique_ptr<tflite::ErrorReporter> optional_error_reporter;
+  if (optional_options && optional_options->error_reporter != nullptr) {
+    optional_error_reporter.reset(
+        new CallbackErrorReporter(optional_options->error_reporter,
+                                  optional_options->error_reporter_user_data));
+  }
+
+  // TODO(b/111881878): Allow use of C API without pulling in all builtin ops.
   tflite::ops::builtin::BuiltinOpResolver resolver;
-  tflite::InterpreterBuilder builder(*model->impl, resolver);
+  if (optional_options) {
+    resolver.AddAll(optional_options->op_resolver);
+  }
+  tflite::ErrorReporter* error_reporter = optional_error_reporter
+                                              ? optional_error_reporter.get()
+                                              : tflite::DefaultErrorReporter();
+  tflite::InterpreterBuilder builder(model->impl->GetModel(), resolver,
+                                     error_reporter);
+
   std::unique_ptr<tflite::Interpreter> interpreter;
   if (builder(&interpreter) != kTfLiteOk) {
     return nullptr;
@@ -76,7 +121,8 @@
     }
   }
 
-  return new TFL_Interpreter{model->impl, std::move(interpreter)};
+  return new TFL_Interpreter{model->impl, std::move(optional_error_reporter),
+                             std::move(interpreter)};
 }
 
 void TFL_DeleteInterpreter(TFL_Interpreter* interpreter) { delete interpreter; }
diff --git a/tensorflow/contrib/lite/experimental/c/c_api.h b/tensorflow/contrib/lite/experimental/c/c_api.h
index b429e76..f52ab8f 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api.h
+++ b/tensorflow/contrib/lite/experimental/c/c_api.h
@@ -15,6 +15,7 @@
 #ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_H_
 #define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_H_
 
+#include <stdarg.h>
 #include <stdint.h>
 
 // Eventually the various C APIs defined in context.h will be migrated into
@@ -52,8 +53,9 @@
 extern "C" {
 #endif  // __cplusplus
 
-typedef TfLiteTensor TFL_Tensor;
+typedef TfLiteRegistration TFL_Registration;
 typedef TfLiteStatus TFL_Status;
+typedef TfLiteTensor TFL_Tensor;
 typedef TfLiteType TFL_Type;
 
 // --------------------------------------------------------------------------
@@ -85,6 +87,17 @@
 TFL_CAPI_EXPORT extern void TFL_InterpreterOptionsSetNumThreads(
     TFL_InterpreterOptions* options, int32_t num_threads);
 
+// Sets a custom error reporter for interpreter execution.
+//
+// * `reporter` takes the provided `user_data` object, as well as a C-style
+//   format string and arg list (see also vprintf).
+// * `user_data` is optional. If provided, it is owned by the client and must
+//   remain valid for the duration of the interpreter lifetime.
+TFL_CAPI_EXPORT extern void TFL_InterpreterOptionsSetErrorReporter(
+    TFL_InterpreterOptions* options,
+    void (*reporter)(void* user_data, const char* format, va_list args),
+    void* user_data);
+
 // --------------------------------------------------------------------------
 // TFL_Interpreter provides inference from a provided model.
 typedef struct TFL_Interpreter TFL_Interpreter;
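A minimal usage sketch for `TFL_InterpreterOptionsSetErrorReporter` declared above, assuming the `add.bin` test model referenced elsewhere in this change and eliding error checks; it simply forwards messages to stderr.

```
#include <cstdarg>
#include <cstdio>

#include "tensorflow/contrib/lite/experimental/c/c_api.h"

// Forwards TF Lite error messages to stderr; `user_data` is unused here but
// could carry a logger instance instead.
static void StderrReporter(void* user_data, const char* format, va_list args) {
  (void)user_data;
  std::vfprintf(stderr, format, args);
  std::fputc('\n', stderr);
}

int main() {
  TFL_Model* model =
      TFL_NewModelFromFile("tensorflow/contrib/lite/testdata/add.bin");
  TFL_InterpreterOptions* options = TFL_NewInterpreterOptions();
  TFL_InterpreterOptionsSetErrorReporter(options, StderrReporter,
                                         /*user_data=*/nullptr);

  TFL_Interpreter* interpreter = TFL_NewInterpreter(model, options);
  // The interpreter keeps what it needs, so the options and model can be
  // released immediately after creation (see the lifetime test in
  // c_api_test.cc).
  TFL_DeleteInterpreterOptions(options);
  TFL_DeleteModel(model);

  // Invoking before AllocateTensors fails and routes the message through the
  // callback installed above.
  TFL_InterpreterInvoke(interpreter);
  TFL_DeleteInterpreter(interpreter);
  return 0;
}
```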
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc b/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc
index c4dbc55..0f16595 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc
+++ b/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc
@@ -26,6 +26,22 @@
   return interpreter->impl->ResetVariableTensorsToZero();
 }
 
+void TFL_InterpreterOptionsAddBuiltinOp(TFL_InterpreterOptions* options,
+                                        TFL_BuiltinOperator op,
+                                        const TFL_Registration* registration,
+                                        int32_t min_version,
+                                        int32_t max_version) {
+  options->op_resolver.AddBuiltin(static_cast<tflite::BuiltinOperator>(op),
+                                  registration, min_version, max_version);
+}
+
+void TFL_InterpreterOptionsAddCustomOp(TFL_InterpreterOptions* options,
+                                       const char* name,
+                                       const TFL_Registration* registration,
+                                       int min_version, int max_version) {
+  options->op_resolver.AddCustom(name, registration, min_version, max_version);
+}
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif  // __cplusplus
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_experimental.h b/tensorflow/contrib/lite/experimental/c/c_api_experimental.h
index b0ac258..b8de7b9 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_experimental.h
+++ b/tensorflow/contrib/lite/experimental/c/c_api_experimental.h
@@ -15,16 +15,41 @@
 #ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_EXPERIMENTAL_H_
 #define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_EXPERIMENTAL_H_
 
+#include "tensorflow/contrib/lite/builtin_ops.h"
 #include "tensorflow/contrib/lite/experimental/c/c_api.h"
 
 #ifdef __cplusplus
 extern "C" {
 #endif  // __cplusplus
 
+typedef TfLiteBuiltinOperator TFL_BuiltinOperator;
+
 // Resets all variable tensors to zero.
 TFL_CAPI_EXPORT extern TFL_Status TFL_InterpreterResetVariableTensorsToZero(
     TFL_Interpreter* interpreter);
 
+// Adds an op registration for a builtin operator.
+//
+// NOTE: The interpreter will make a copy of `registration` internally, so the
+// caller should ensure that its contents (function pointers, etc...) remain
+// valid for the duration of the interpreter's lifetime. A common practice is
+// making the provided TFL_Registration instance static.
+void TFL_InterpreterOptionsAddBuiltinOp(TFL_InterpreterOptions* options,
+                                        TFL_BuiltinOperator op,
+                                        const TFL_Registration* registration,
+                                        int min_version, int max_version);
+
+// Adds an op registration for a custom operator.
+//
+// NOTE: The interpreter will make a copy of `registration` internally, so the
+// caller should ensure that its contents (function pointers, etc...) remain
+// valid for the duration of the interpreter's lifetime. A common practice is
+// making the provided TFL_Registration instance static.
+void TFL_InterpreterOptionsAddCustomOp(TFL_InterpreterOptions* options,
+                                       const char* name,
+                                       const TFL_Registration* registration,
+                                       int min_version, int max_version);
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif  // __cplusplus
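A short sketch of registering ops through these entry points; the no-op kernel and the custom-op name `MY_CUSTOM_OP` are hypothetical, and the registration is kept `static` per the note above since only its pointer contents are copied.

```
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/experimental/c/c_api.h"
#include "tensorflow/contrib/lite/experimental/c/c_api_experimental.h"

// A do-nothing kernel; a real custom op would implement prepare/invoke.
static TfLiteRegistration* GetNoopRegistration() {
  static TfLiteRegistration registration = {
      .init = nullptr,
      .free = nullptr,
      .prepare = nullptr,
      .invoke = [](TfLiteContext*, TfLiteNode*) { return kTfLiteOk; },
  };
  return &registration;
}

static void ConfigureOptions(TFL_InterpreterOptions* options) {
  // Resolve the builtin ADD op explicitly, and a hypothetical custom op by
  // name.
  TFL_InterpreterOptionsAddBuiltinOp(options, kTfLiteBuiltinAdd,
                                     GetNoopRegistration(), /*min_version=*/1,
                                     /*max_version=*/1);
  TFL_InterpreterOptionsAddCustomOp(options, "MY_CUSTOM_OP",
                                    GetNoopRegistration(), /*min_version=*/1,
                                    /*max_version=*/1);
}
```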
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc b/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc
index db6e525..d86ad00 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc
+++ b/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc
@@ -16,25 +16,40 @@
 #include "tensorflow/contrib/lite/experimental/c/c_api_experimental.h"
 
 #include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/builtin_ops.h"
 #include "tensorflow/contrib/lite/experimental/c/c_api.h"
 #include "tensorflow/contrib/lite/testing/util.h"
 
 namespace {
 
+TfLiteRegistration* GetDummyRegistration() {
+  static TfLiteRegistration registration = {
+      .init = nullptr,
+      .free = nullptr,
+      .prepare = nullptr,
+      .invoke = [](TfLiteContext*, TfLiteNode*) { return kTfLiteOk; },
+  };
+  return &registration;
+}
+
 TEST(CApiExperimentalSimple, Smoke) {
   TFL_Model* model = TFL_NewModelFromFile(
       "tensorflow/contrib/lite/testdata/add.bin");
   ASSERT_NE(model, nullptr);
 
-  TFL_Interpreter* interpreter =
-      TFL_NewInterpreter(model, /*optional_options=*/nullptr);
+  TFL_InterpreterOptions* options = TFL_NewInterpreterOptions();
+  TFL_InterpreterOptionsAddBuiltinOp(options, kTfLiteBuiltinAdd,
+                                     GetDummyRegistration(), 1, 1);
+
+  TFL_Interpreter* interpreter = TFL_NewInterpreter(model, options);
   ASSERT_NE(interpreter, nullptr);
   ASSERT_EQ(TFL_InterpreterAllocateTensors(interpreter), kTfLiteOk);
-
   EXPECT_EQ(TFL_InterpreterResetVariableTensorsToZero(interpreter), kTfLiteOk);
+  EXPECT_EQ(TFL_InterpreterInvoke(interpreter), kTfLiteOk);
 
-  TFL_DeleteModel(model);
   TFL_DeleteInterpreter(interpreter);
+  TFL_DeleteInterpreterOptions(options);
+  TFL_DeleteModel(model);
 }
 
 }  // namespace
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_internal.h b/tensorflow/contrib/lite/experimental/c/c_api_internal.h
index 60c2e4e..da3af3c 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_internal.h
+++ b/tensorflow/contrib/lite/experimental/c/c_api_internal.h
@@ -19,9 +19,13 @@
 
 #include "tensorflow/contrib/lite/interpreter.h"
 #include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/op_resolver.h"
 
 // Internal structures used by the C API. These are likely to change and should
 // not be depended on.
+//
+// NOTE: This header does not follow C conventions and does not define a C API.
+// It is effectively an (internal) implementation detail of the C API.
 
 struct TFL_Model {
   // Sharing is safe as FlatBufferModel is const.
@@ -33,12 +37,24 @@
     kDefaultNumThreads = -1,
   };
   int num_threads = kDefaultNumThreads;
+
+  tflite::MutableOpResolver op_resolver;
+
+  void (*error_reporter)(void* user_data, const char* format,
+                         va_list args) = nullptr;
+  void* error_reporter_user_data = nullptr;
 };
 
 struct TFL_Interpreter {
   // Taking a reference to the (const) model data avoids lifetime-related issues
   // and complexity with the TFL_Model's existence.
   std::shared_ptr<const tflite::FlatBufferModel> model;
+
+  // The interpreter does not take ownership of the provided ErrorReporter
+  // instance, so we ensure its validity here. Note that the interpreter may use
+  // the reporter in its destructor, so it should be declared first.
+  std::unique_ptr<tflite::ErrorReporter> optional_error_reporter;
+
   std::unique_ptr<tflite::Interpreter> impl;
 };
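The member order above matters because C++ destroys non-static data members in reverse declaration order. A standalone sketch of that rule, using placeholder types rather than the real interpreter classes:

```
#include <iostream>
#include <memory>

struct Reporter {
  ~Reporter() { std::cout << "reporter destroyed\n"; }
};

struct Engine {
  // Stands in for tflite::Interpreter, which may still report errors while
  // being torn down.
  ~Engine() { std::cout << "engine destroyed\n"; }
};

struct Holder {
  std::unique_ptr<Reporter> reporter;  // declared first  => destroyed last
  std::unique_ptr<Engine> engine;      // declared second => destroyed first
};

int main() {
  Holder h;
  h.reporter.reset(new Reporter);
  h.engine.reset(new Engine);
  // On return, h prints "engine destroyed" then "reporter destroyed", which is
  // why the error reporter must be declared before the interpreter member.
  return 0;
}
```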
 
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_test.cc b/tensorflow/contrib/lite/experimental/c/c_api_test.cc
index 649dac8..48a3714 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_test.cc
+++ b/tensorflow/contrib/lite/experimental/c/c_api_test.cc
@@ -85,6 +85,37 @@
   TFL_DeleteInterpreter(interpreter);
 }
 
+TEST(CApiSimple, ErrorReporter) {
+  TFL_Model* model = TFL_NewModelFromFile(
+      "tensorflow/contrib/lite/testdata/add.bin");
+  TFL_InterpreterOptions* options = TFL_NewInterpreterOptions();
+
+  // Install a custom error reporter into the interpreter by way of options.
+  tflite::TestErrorReporter reporter;
+  TFL_InterpreterOptionsSetErrorReporter(
+      options,
+      [](void* user_data, const char* format, va_list args) {
+        reinterpret_cast<tflite::TestErrorReporter*>(user_data)->Report(format,
+                                                                        args);
+      },
+      &reporter);
+  TFL_Interpreter* interpreter = TFL_NewInterpreter(model, options);
+
+  // The options/model can be deleted immediately after interpreter creation.
+  TFL_DeleteInterpreterOptions(options);
+  TFL_DeleteModel(model);
+
+  // Invoke the interpreter before tensor allocation.
+  EXPECT_EQ(TFL_InterpreterInvoke(interpreter), kTfLiteError);
+
+  // The error should propagate to the custom error reporter.
+  EXPECT_EQ(reporter.error_messages(),
+            "Invoke called on model that is not ready.");
+  EXPECT_EQ(reporter.num_calls(), 1);
+
+  TFL_DeleteInterpreter(interpreter);
+}
+
 }  // namespace
 
 int main(int argc, char** argv) {
diff --git a/tensorflow/contrib/lite/g3doc/_book.yaml b/tensorflow/contrib/lite/g3doc/_book.yaml
index 1dffe30..beaa5c4 100644
--- a/tensorflow/contrib/lite/g3doc/_book.yaml
+++ b/tensorflow/contrib/lite/g3doc/_book.yaml
@@ -5,7 +5,7 @@
 # Dropdown menu
 - name: Ecosystem
   path: /ecosystem
-  is_default: True
+  is_default: true
   menu:
   - include: /ecosystem/_menu_toc.yaml
   lower_tabs:
@@ -14,46 +14,49 @@
     - name: Guide
       contents:
       - title: Overview
-        path: /mobile/overview
-      - title: Developer Guide
-        path: /mobile/devguide
-      - title: Android Demo App
-        path: /mobile/demo_android
-      - title: iOS Demo App
-        path: /mobile/demo_ios
+        path: /lite/overview
+      - title: Developer guide
+        path: /lite/devguide
+      - title: Android demo app
+        path: /lite/demo_android
+      - title: iOS demo app
+        path: /lite/demo_ios
       - title: Performance
-        path: /mobile/performance
-      - break: True
+        path: /lite/performance
+      - break: true
       - title: TensorFlow Lite APIs
-        path: /mobile/apis
+        path: /lite/apis
       - title: Custom operators
-        path: /mobile/custom_operators
-      - title: TensorFlow Lite Ops Versioning
-        path: /mobile/ops_versioning
-      - title: TensorFlow Lite Compatibility Guide
-        path: /mobile/tf_ops_compatibility
-      - title: List of Hosted Models
-        path: /mobile/models
+        path: /lite/custom_operators
+      - title: TensorFlow Lite ops versioning
+        path: /lite/ops_versioning
+      - title: TensorFlow Lite compatibility guide
+        path: /lite/tf_ops_compatibility
+      - title: List of hosted models
+        path: /lite/models
       - title: TensorFlow Lite for iOS
-        path: /mobile/ios
+        path: /lite/ios
       - title: TensorFlow Lite for Raspberry Pi
-        path: /mobile/rpi
+        path: /lite/rpi
 
-      - heading: TF Mobile
+      - title: TF Mobile
+        style: accordion
         status: deprecated
-      - title: Overview
-        path: /mobile/tfmobile/
-      - title: Building TensorFlow on Android
-        path: /mobile/tfmobile/android_build
-      - title: Building TensorFlow on IOS
-        path: /mobile/tfmobile/ios_build
-      - title: Integrating TensorFlow libraries
-        path: /mobile/tfmobile/linking_libs
-      - title: Preparing models for mobile deployment
-        path: /mobile/tfmobile/prepare_models
-      - title: Optimizing for mobile
-        path: /mobile/tfmobile/optimizing
+        section:
+        - title: Overview
+          path: /lite/tfmobile/
+        - title: Building TensorFlow on Android
+          path: /lite/tfmobile/android_build
+        - title: Building TensorFlow on IOS
+          path: /lite/tfmobile/ios_build
+        - title: Integrating TensorFlow libraries
+          path: /lite/tfmobile/linking_libs
+        - title: Preparing models for mobile deployment
+          path: /lite/tfmobile/prepare_models
+        - title: Optimizing for mobile
+          path: /lite/tfmobile/optimizing
 
     - name: API
       contents:
-      - include: /mobile/api_docs/python/_toc.yaml
+      - title: API
+        path: /api_docs/python/tf/contrib/lite
diff --git a/tensorflow/contrib/lite/g3doc/_index.yaml b/tensorflow/contrib/lite/g3doc/_index.yaml
index b3f21e2..bc66cc5 100644
--- a/tensorflow/contrib/lite/g3doc/_index.yaml
+++ b/tensorflow/contrib/lite/g3doc/_index.yaml
@@ -1,60 +1,209 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
+project_path: /lite/_project.yaml
+book_path: /lite/_book.yaml
 description: <!--no description-->
 landing_page:
+  custom_css_path: /site-assets/css/style.css
   rows:
-  - heading: TensorFlow Lite is a lightweight solution for mobile and embedded devices.
+  - heading: TensorFlow Lite is for mobile and embedded devices.
+    description: >
+      <p style="max-width: 75%;">
+        TensorFlow Lite is the official solution for running machine learning
+        models on mobile and embedded devices. It enables on&#8209;device machine
+        learning inference with low latency and a small binary size on Android,
+        iOS, and other operating systems.
+      </p>
+      <style>
+      .tfo-landing-row-heading {
+        padding-top: 0 !important;
+      }
+      .tfo-landing-row-heading h2 {
+        margin-top: 0 !important;
+      }
+      .tfo-landing-row-heading-list ol, .tfo-landing-row-heading-list ul {
+        margin-top: 0;
+      }
+      </style>
+
+  - classname: tfo-landing-row-heading tfo-landing-row-heading-list
+    heading: Many benefits
+    description: >
+      On-device ML inference is difficult because of the tight constraints of mobile and embedded devices. TensorFlow Lite is designed to address these:
     items:
-    - classname: devsite-landing-row-50
+    - list:
+      - heading: Performance
+        description: >
+          TF Lite is fast with no noticeable accuracy loss—see the <a href="./performance">metrics</a>.
+        icon:
+          icon_name: lens
+          foreground: theme
+      - heading: Portability
+        description: >
+          <a href="https://developer.android.com/ndk/guides/neuralnetworks/" class="external">Android</a>,
+          iOS, and more specialized IoT devices.
+        icon:
+          icon_name: lens
+          foreground: theme
+    - list:
+      - heading: Low latency
+        description: >
+          Optimized float- and fixed-point CPU kernels, op&#8209;fusing, and more.
+        icon:
+          icon_name: lens
+          foreground: theme
+      - heading: Acceleration
+        description: >
+          Integration with GPU and internal/external accelerators.
+        icon:
+          icon_name: lens
+          foreground: theme
+    - list:
+      - heading: Small model size
+        description: >
+          Controlled dependencies, <a href="https://medium.com/tensorflow/introducing-the-model-optimization-toolkit-for-tensorflow-254aca1ba0a3" class="external">quantization</a>,
+          and op&nbsp;registration.
+        icon:
+          icon_name: lens
+          foreground: theme
+      - heading: Tooling
+        description: >
+          Conversion, compression, benchmarking, power consumption, and more.
+        icon:
+          icon_name: lens
+          foreground: theme
+
+  - classname: devsite-landing-row-logos tfo-landing-row-heading
+    heading: Companies using TensorFlow Lite
+    items:
+    - custom_image:
+        path: ./images/landing-page/photos_logo.png
+      path: https://www.photos.google.com
+    - custom_image:
+        path: ./images/landing-page/gboard_logo.png
+      path: https://play.google.com/store/apps/details?id=com.google.android.inputmethod.latin&hl=en_US
+    - custom_image:
+        path: ./images/landing-page/gmail_logo.png
+      path: https://www.google.com/gmail/
+    - custom_image:
+        path: ./images/landing-page/assistant_logo.png
+      path: https://assistant.google.com/
+
+  - classname: devsite-landing-row-logos
+    items:
+    - custom_image:
+        path: ./images/landing-page/vsco_logo.png
+      path: https://vsco.co
+    - custom_image:
+        path: ./images/landing-page/shazam_logo.png
+      path: https://www.shazam.com/
+    - custom_image:
+        path: ./images/landing-page/nest_logo.png
+      path: https://nest.com/
+    - custom_image:
+        path: ./images/landing-page/loseit_logo.png
+      path: https://www.loseit.com/
+
+  - classname: devsite-landing-row-no-image-background devsite-landing-row-67
+    background: grey
+    items:
+    - description: >
+        <em>“TensorFlow Lite helped us introduce machine learning and AI into our
+        app in an easy and streamlined way. We could reduce the size of our
+        models while keeping the accuracy high. This helped us create an amazing
+        fishing experience for our users by allowing them to identify any fish
+        species with just a photo.”</em>
+      image_path: ./images/landing-page/fishbrain_logo_big.png
+
+  - heading: How it works
+    items:
+    - heading: Build
+      icon:
+        icon_name: build
       description: >
-        TensorFlow Lite is TensorFlow’s lightweight solution for mobile and
-        embedded devices. It enables on-device machine learning inference with
-        low latency and a small binary size. TensorFlow Lite also supports
-        hardware acceleration with the
-        <a href='https://developer.android.com/ndk/guides/neuralnetworks/index.html'>Android Neural Networks API</a>.
-      list:
-      - heading: Key point 1
-        description: >
-          [high-level overview]
-        icon:
-          icon_name: chevron_right
-          foreground: theme
-          background: grey
-      - heading: Key point 2
-        description: >
-          [high-level overview]
-        icon:
-          icon_name: chevron_right
-          foreground: theme
-          background: grey
-      - heading: Key point 3
-        description: >
-          [high-level overview]
-        icon:
-          icon_name: chevron_right
-          foreground: theme
-          background: grey
-      code_block: |
-        <pre class = "prettyprint">
-        $ toco --input_file=$(pwd)/mobilenet_v1_1.0_224/frozen_graph.pb \
-               --input_format=TENSORFLOW_GRAPHDEF \
-               --output_format=TFLITE \
-               --output_file=/tmp/mobilenet_v1_1.0_224.tflite \
-               --inference_type=FLOAT \
-               --input_type=FLOAT \
-               --input_arrays=input \
-               --output_arrays=MobilenetV1/Predictions/Reshape_1 \
-               --input_shapes=1,224,224,3
-        </pre>
+        Build a new model or retrain an existing one, such as using transfer learning.
+      buttons:
+      - label: Read the developer guide
+        path: /lite/devguide
+        classname: button button-primary tfo-button-primary
+    - heading: Convert
+      icon:
+        icon_name: autorenew
+      description: >
+        Convert a TensorFlow model into a compressed flat buffer with the
+        TensorFlow Lite Optimizing Converter (TOCO).
+      buttons:
+      - label: Read the TOCO guide
+        path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/toco/g3doc/python_api.md
+        classname: button button-primary tfo-button-primary
+    - heading: Deploy
+      icon:
+        icon_name: bolt
+      description: >
+        Take the compressed <code>.tflite</code> file and load it into a mobile
+        or embedded device.<br/>
+        See the <a href="#build-your-first-tensorflow-lite-app">tutorials below</a> to build an app.
+
+  - heading: Build your first TensorFlow Lite app
+    background: grey
+    items:
+    - classname: tfo-landing-row-item-inset-white
+      heading: Get started
+      description: >
+        <ul>
+          <li>Beginner: <a href="https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/" class="external">TensorFlow for Poets</a></li>
+          <li>Beginner: <a href="https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2-tflite/" class="external">TensorFlow for Poets 2: Android</a></li>
+          <li>Beginner: <a href="https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2-ios/" class="external">TensorFlow for Poets 2: iOS</a></li>
+          <li>Intermediate: <a href="https://medium.com/tensorflow/training-and-serving-a-realtime-mobile-object-detector-in-30-minutes-with-cloud-tpus-b78971cf1193" class="external">Object detection tutorial</a></li>
+        </ul>
+    - classname: tfo-landing-row-item-inset-white
+      heading: Share your TensorFlow Lite story
+      description: >
+        We love to hear what you're working on—it may even get highlighted on
+        our social media! <a href="https://groups.google.com/a/tensorflow.org/forum/#!forum/discuss" class="external">Tell us</a>.
+
+  - classname: devsite-landing-row-no-image-background devsite-landing-row-67
+    items:
+    - description: >
+        <p>
+          <em>“The release of TensorFlow Lite has allowed us to deploy an engaging
+          real-time experience to our users that eliminates the requirement
+          for a data connection. TensorFlow Lite’s ability to compress and
+          optimize the TensorFlow graph for mobile deployment has been
+          transformative in expanding the capabilities of Snap It.</em>
+        </p>
+        <p>
+          <em>Through TensorFlow Lite, our users can now enjoy a state of the
+          art, computer-vision-based food logging experience without worrying
+          about signal strength. We look forward to future collaborations
+          with the TensorFlow Lite team.”</em>
+        </p>
+      image_path: ./images/landing-page/loseit_logo_big.png
 
   - classname: devsite-landing-row-cards
+    background: grey
+    heading: Updates
     items:
+    - heading: Introducing the Model Optimization Toolkit
+      image_path: /ecosystem/images/tf-logo-card-16x9.png
+      path: https://medium.com/tensorflow/introducing-the-model-optimization-toolkit-for-tensorflow-254aca1ba0a3
+      buttons:
+      - label: Read on TensorFlow blog
+        path: https://medium.com/tensorflow/introducing-the-model-optimization-toolkit-for-tensorflow-254aca1ba0a3
+    - heading: East Africa Cassava App
+      image_path: ./images/landing-page/detect_crop_disease_in_africa.png
+      path: https://heartbeat.fritz.ai/community-spotlight-nuru-a-mobile-app-by-plantvillage-to-detect-crop-disease-in-africa-28d142bf63d5
+      buttons:
+      - label: Read more
+        path: https://heartbeat.fritz.ai/community-spotlight-nuru-a-mobile-app-by-plantvillage-to-detect-crop-disease-in-africa-28d142bf63d5
     - heading: Using TensorFlow Lite on Android
       image_path: /ecosystem/images/tf-logo-card-16x9.png
       path: https://medium.com/tensorflow/using-tensorflow-lite-on-android-9bbc9cb7d69d
       buttons:
       - label: Read on TensorFlow blog
         path: https://medium.com/tensorflow/using-tensorflow-lite-on-android-9bbc9cb7d69d
+
+  - classname: devsite-landing-row-cards
+    background: grey
+    items:
     - heading: TensorFlow Lite at the Dev Summit
       youtube_id: FAMfy7izB6A
       buttons:
@@ -66,3 +215,4 @@
       buttons:
       - label: View on GitHub
         path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite
+    - classname: devsite-landing-row-item-hidden
diff --git a/tensorflow/contrib/lite/g3doc/_project.yaml b/tensorflow/contrib/lite/g3doc/_project.yaml
index b396665..3ce6986 100644
--- a/tensorflow/contrib/lite/g3doc/_project.yaml
+++ b/tensorflow/contrib/lite/g3doc/_project.yaml
@@ -1,10 +1,10 @@
 name: TensorFlow Lite
-breadcrumb_name: Mobile
-home_url: /mobile/
+breadcrumb_name: TensorFlow Lite
+home_url: /lite/
 parent_project_metadata_path: /_project.yaml
 description: >
   TensorFlow Lite is a lightweight solution for mobile and embedded devices.
-use_site_branding: True
-hide_from_products_list: True
+use_site_branding: true
+hide_from_products_list: true
 content_license: cc3-apache2
 buganizer_id: 316308
diff --git a/tensorflow/contrib/lite/g3doc/api_docs/python/_toc.yaml b/tensorflow/contrib/lite/g3doc/api_docs/python/_toc.yaml
deleted file mode 100644
index 1e1c44c..0000000
--- a/tensorflow/contrib/lite/g3doc/api_docs/python/_toc.yaml
+++ /dev/null
@@ -1,6 +0,0 @@
-# Automatically generated file; please do not edit
-toc:
-  - title: TensorFlow Lite
-    section:
-    - title: Overview
-      path: /mobile/api_docs/python/
diff --git a/tensorflow/contrib/lite/g3doc/devguide.md b/tensorflow/contrib/lite/g3doc/devguide.md
index 90e7915..0eed516 100644
--- a/tensorflow/contrib/lite/g3doc/devguide.md
+++ b/tensorflow/contrib/lite/g3doc/devguide.md
@@ -1,5 +1,4 @@
-
-# Developer Guide
+# TF Lite Developer Guide
 
 Using a TensorFlow Lite model in your mobile app requires multiple
 considerations: you must choose a pre-trained or custom model, convert the model
@@ -55,7 +54,7 @@
 ### Train a custom model
 
 A developer may choose to train a custom model using Tensorflow (see the
-[TensorFlow tutorials](../../tutorials/) for examples of building and training
+[TensorFlow tutorials](../tutorials/) for examples of building and training
 models). If you have already written a model, the first step is to export this
 to a `tf.GraphDef` file. This is required because some formats do not store the
 model structure outside the code, and we must communicate with other parts of the
@@ -205,7 +204,7 @@
 [on GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/app).
 You can also download a
 [prebuilt APK](http://download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk).
-See the <a href="../demo_android.md">Android demo</a> guide for details.
+See the <a href="./demo_android.md">Android demo</a> guide for details.
 
 The <a href="./android_build.md">Android mobile</a> guide has instructions for
 installing TensorFlow on Android and setting up `bazel` and Android Studio.
@@ -214,7 +213,7 @@
 
 To integrate a TensorFlow model in an iOS app, see the
 [TensorFlow Lite for iOS](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/ios.md)
-guide and <a href="../demo_ios.md">iOS demo</a> guide.
+guide and <a href="./demo_ios.md">iOS demo</a> guide.
 
 #### Core ML support
 
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/assistant_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/assistant_logo.png
new file mode 100644
index 0000000..ced0872
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/assistant_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/detect_crop_disease_in_africa.png b/tensorflow/contrib/lite/g3doc/images/landing-page/detect_crop_disease_in_africa.png
new file mode 100644
index 0000000..45b3b4f
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/detect_crop_disease_in_africa.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo.png
new file mode 100644
index 0000000..bc1bf6e
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo_big.png b/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo_big.png
new file mode 100644
index 0000000..d76fca8
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo_big.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/gboard_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/gboard_logo.png
new file mode 100644
index 0000000..f1a93ab
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/gboard_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/gmail_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/gmail_logo.png
new file mode 100644
index 0000000..21aa2c8
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/gmail_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo.png
new file mode 100644
index 0000000..b6b3d14
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo_big.png b/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo_big.png
new file mode 100644
index 0000000..b3e46d4
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo_big.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/nest_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/nest_logo.png
new file mode 100644
index 0000000..35bfd97
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/nest_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/photos_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/photos_logo.png
new file mode 100644
index 0000000..4333426
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/photos_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/shazam_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/shazam_logo.png
new file mode 100644
index 0000000..6ec412c
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/shazam_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/vsco_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/vsco_logo.png
new file mode 100644
index 0000000..f408f90
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/vsco_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/ios.md b/tensorflow/contrib/lite/g3doc/ios.md
index a83d2c8..3b9fcca 100644
--- a/tensorflow/contrib/lite/g3doc/ios.md
+++ b/tensorflow/contrib/lite/g3doc/ios.md
@@ -1,5 +1,10 @@
 
-# TensorFlow Lite for iOS
+# Build TensorFlow Lite for iOS
+
+This document describes how to build TensorFlow Lite iOS library. If you just
+want to use it, the easiest way is using the TensorFlow Lite CocoaPod releases.
+See [TensorFlow Lite iOS Demo](demo_ios.md) for examples.
+
 
 ## Building
 
diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
index 8660d29..b0dfb0f 100644
--- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
+++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
@@ -866,6 +866,17 @@
 }
 ```
 
+**ZEROS_LIKE**
+
+```
+Inputs {
+  0: a tensor
+}
+Outputs {
+  0: a tensor of the same shape and type as the input, but filled with zeros
+}
+```
+
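For illustration only, a hypothetical kernel `Eval` with these semantics, assuming the usual helpers from `kernel_util.h`; this is a sketch, not the `zeros_like.cc` kernel added in this change.

```
#include <cstring>

#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"

namespace tflite {
namespace {

// Hypothetical ZEROS_LIKE Eval: the output tensor already has the input's
// shape and type, so zeroing its raw buffer yields the documented result for
// numeric types.
TfLiteStatus ZerosLikeEval(TfLiteContext* context, TfLiteNode* node) {
  TfLiteTensor* output = GetOutput(context, node, /*index=*/0);
  std::memset(output->data.raw, 0, output->bytes);
  return kTfLiteOk;
}

}  // namespace
}  // namespace tflite
```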
 And these are TensorFlow Lite operations that are present but not ready for
 custom models yet:
 
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md b/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md
index c7cdee0..b0f32a8 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md
@@ -93,7 +93,7 @@
 guide you through the basics here.
 
 - First, follow our instructions for
-  <a href="http://www.tensorflow.org/install/install_sources">installing from sources</a>.
+  <a href="http://www.tensorflow.org/install/source">installing from sources</a>.
   This will also guide you through installing Bazel and cloning the
   TensorFlow code.
 
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/index.md b/tensorflow/contrib/lite/g3doc/tfmobile/index.md
index d003bb2..49ad35d 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/index.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/index.md
@@ -4,7 +4,7 @@
 TensorFlow was designed to be a good deep learning solution for mobile
 platforms. Currently we have two solutions for deploying machine learning
 applications on mobile and embedded devices: TensorFlow for Mobile and
-<a href="../index.md">TensorFlow Lite</a>.
+<a href="../../lite">TensorFlow Lite</a>.
 
 ## TensorFlow Lite versus TensorFlow Mobile
 
diff --git a/tensorflow/contrib/lite/java/ovic/BUILD b/tensorflow/contrib/lite/java/ovic/BUILD
index 781289c..bb0be04 100644
--- a/tensorflow/contrib/lite/java/ovic/BUILD
+++ b/tensorflow/contrib/lite/java/ovic/BUILD
@@ -44,6 +44,7 @@
 android_library(
     name = "ovicbenchmarkerlib",
     srcs = [
+        "src/main/java/org/tensorflow/ovic/OvicBenchmarker.java",
         "src/main/java/org/tensorflow/ovic/OvicClassifier.java",
         "src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java",
     ],
diff --git a/tensorflow/contrib/lite/java/ovic/README.md b/tensorflow/contrib/lite/java/ovic/README.md
index 2634934..df77bfa 100644
--- a/tensorflow/contrib/lite/java/ovic/README.md
+++ b/tensorflow/contrib/lite/java/ovic/README.md
@@ -4,7 +4,7 @@
 
 ## Pre-requisite
 
-Follow the steps [here](https://www.tensorflow.org/mobile/tflite/demo_android) to install Tensorflow, Bazel, and the Android NDK and SDK.
+Follow the steps [here](https://www.tensorflow.org/lite/demo_android) to install TensorFlow, Bazel, and the Android NDK and SDK.
 
 ## Test the benchmarker:
 
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD
index a8d751a..b2e3a9bd 100644
--- a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD
@@ -6,7 +6,6 @@
 android_binary(
     name = "ovic_benchmarker_binary",
     srcs = [
-        "OvicBenchmarker.java",
         "OvicBenchmarkerActivity.java",
     ],
     assets = [
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java b/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java
index 59457c3..4adf94a 100644
--- a/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java
@@ -34,8 +34,10 @@
 import java.nio.MappedByteBuffer;
 import java.nio.channels.FileChannel;
 import java.text.DecimalFormat;
+import org.tensorflow.ovic.OvicBenchmarker;
 import org.tensorflow.ovic.OvicSingleImageResult;
 
+
 /** Class that benchmark image classifier models. */
 public class OvicBenchmarkerActivity extends Activity {
   /** Tag for the {@link Log}. */
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarker.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java
similarity index 97%
rename from tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarker.java
rename to tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java
index 113ab74..4cda258 100644
--- a/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarker.java
+++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java
@@ -12,7 +12,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-package ovic.demo.app;
+package org.tensorflow.ovic;
 
 import android.graphics.Bitmap;
 import android.os.SystemClock;
@@ -22,8 +22,6 @@
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 import java.nio.MappedByteBuffer;
-import org.tensorflow.ovic.OvicClassifier;
-import org.tensorflow.ovic.OvicSingleImageResult;
 
 /**
  * Class that benchmarks image classifier models.
diff --git a/tensorflow/contrib/lite/kernels/Android.bp b/tensorflow/contrib/lite/kernels/Android.bp
index 01382b7..ed45952 100644
--- a/tensorflow/contrib/lite/kernels/Android.bp
+++ b/tensorflow/contrib/lite/kernels/Android.bp
@@ -97,6 +97,7 @@
         "unidirectional_sequence_lstm.cc",
         "unidirectional_sequence_rnn.cc",
         "unpack.cc",
+        "zeros_like.cc",
 	"internal/kernel_utils.cc",
         "internal/tensor_utils.cc",
         "internal/quantization_util.cc",
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index 40f28ae..daaf671 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -223,6 +223,7 @@
         "unidirectional_sequence_lstm.cc",
         "unidirectional_sequence_rnn.cc",
         "unpack.cc",
+        "zeros_like.cc",
     ],
     hdrs = [
     ],
@@ -508,6 +509,7 @@
         ":builtin_ops",
         "//tensorflow/contrib/lite:framework",
         "//tensorflow/contrib/lite/kernels:test_util",
+        "@com_google_absl//absl/memory",
         "@com_google_googletest//:gtest",
     ],
 )
@@ -1284,6 +1286,20 @@
     ],
 )
 
+tf_cc_test(
+    name = "zeros_like_test",
+    size = "small",
+    srcs = ["zeros_like_test.cc"],
+    tags = ["tflite_not_portable_ios"],
+    deps = [
+        ":builtin_ops",
+        "//tensorflow/contrib/lite:builtin_op_data",
+        "//tensorflow/contrib/lite:framework",
+        "//tensorflow/contrib/lite/kernels:test_util",
+        "@com_google_googletest//:gtest",
+    ],
+)
+
 filegroup(
     name = "all_files",
     srcs = glob(
diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
index 3e1ce60..798ee84 100644
--- a/tensorflow/contrib/lite/kernels/depthwise_conv.cc
+++ b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
@@ -184,17 +184,7 @@
                          const Dims<4>&, const float*, const Dims<4>&, int, int,
                          int, int, int, int, int, float, float, float*,
                          const Dims<4>&);
-  KernelType effective_kernel_type;
-  // TODO(suharshs): Currently only the reference implementation supports
-  // dilations.
-  if ((params->dilation_width_factor != 1) ||
-      (params->dilation_height_factor != 1)) {
-    effective_kernel_type = kReference;
-  } else {
-    effective_kernel_type = kernel_type;
-  }
-
-  if (effective_kernel_type == kReference) {
+  if (kernel_type == kReference) {
     depthwise_conv = &reference_ops::DepthwiseConv;
   } else {
     depthwise_conv = &optimized_ops::DepthwiseConv;
@@ -224,17 +214,7 @@
                          int, int, int, int, int, int, int, int32, int32, int,
                          int32, int32, uint8*, const Dims<4>&);
 
-  KernelType effective_kernel_type;
-  // TODO(suharshs): Currently only the reference implementation supports
-  // dilations.
-  if ((params->dilation_width_factor != 1) ||
-      (params->dilation_height_factor != 1)) {
-    effective_kernel_type = kReference;
-  } else {
-    effective_kernel_type = kernel_type;
-  }
-
-  if (effective_kernel_type == kReference) {
+  if (kernel_type == kReference) {
     depthwise_conv = &reference_ops::DepthwiseConv;
   } else {
     depthwise_conv = &optimized_ops::DepthwiseConv;
diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc
index 2af26ab..4a33a03 100644
--- a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc
+++ b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc
@@ -14,12 +14,24 @@
 ==============================================================================*/
 #include <cstdarg>
 #include <gtest/gtest.h>
+#include "absl/memory/memory.h"
 #include "tensorflow/contrib/lite/interpreter.h"
 #include "tensorflow/contrib/lite/kernels/register.h"
 #include "tensorflow/contrib/lite/kernels/test_util.h"
 #include "tensorflow/contrib/lite/model.h"
 
 namespace tflite {
+
+namespace ops {
+namespace builtin {
+
+TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_REF();
+TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_GENERIC_OPT();
+TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_NEON_OPT();
+
+}  // namespace builtin
+}  // namespace ops
+
 namespace {
 
 using ::testing::ElementsAreArray;
@@ -28,9 +40,11 @@
  public:
   // TODO(ahentz): Also test different activation types, bias, padding types,
   // stride values.
-  BaseDepthwiseConvolutionOpModel(const TensorData& input,
+  BaseDepthwiseConvolutionOpModel(TfLiteRegistration* registration,
+                                  const TensorData& input,
                                   const TensorData& filter,
                                   const TensorData& output,
+                                  Padding padding_type,
                                   int dilation_factor = 1) {
     input_ = AddInput(input);
     filter_ = AddInput(filter);
@@ -56,11 +70,14 @@
     SetBuiltinOp(
         BuiltinOperator_DEPTHWISE_CONV_2D,
         BuiltinOptions_DepthwiseConv2DOptions,
-        CreateDepthwiseConv2DOptions(builder_, Padding_VALID, 1, 1, depth_mul,
+        CreateDepthwiseConv2DOptions(builder_, padding_type, 1, 1, depth_mul,
                                      ActivationFunctionType_NONE,
                                      dilation_factor, dilation_factor)
             .Union());
 
+    resolver_ = absl::make_unique<SingleOpResolver>(
+        BuiltinOperator_DEPTHWISE_CONV_2D, registration);
+
     BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)});
   }
 
@@ -86,10 +103,25 @@
   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
 };
 
-TEST(DepthwiseConvolutionOpTest, SimpleTest) {
-  DepthwiseConvolutionOpModel m({TensorType_FLOAT32, {1, 3, 2, 2}},
+const auto kKernelMap = new std::map<string, TfLiteRegistration*>({
+    {"Reference", ops::builtin::Register_DEPTHWISE_CONVOLUTION_REF()},
+    {"GenericOptimized",
+     ops::builtin::Register_DEPTHWISE_CONVOLUTION_GENERIC_OPT()},
+    {"NeonOptimized", ops::builtin::Register_DEPTHWISE_CONVOLUTION_NEON_OPT()},
+});
+
+class DepthwiseConvolutionOpTest : public SingleOpTest {
+ protected:
+  const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
+    return *kKernelMap;
+  }
+};
+
+TEST_P(DepthwiseConvolutionOpTest, SimpleTest) {
+  DepthwiseConvolutionOpModel m(GetRegistration(),
+                                {TensorType_FLOAT32, {1, 3, 2, 2}},
                                 {TensorType_FLOAT32, {1, 2, 2, 4}},
-                                {TensorType_FLOAT32, {}});
+                                {TensorType_FLOAT32, {}}, Padding_VALID);
 
   m.SetInput({
       1, 2, 7, 8,    // column 1
@@ -112,7 +144,7 @@
                              }));
 }
 
-TEST(DepthwiseConvolutionOpTest, SimpleDilatedTest) {
+TEST_P(DepthwiseConvolutionOpTest, SimpleDilatedTestPaddingValid) {
   const int depth = 1;
   const int image_width = 9;
   const int image_height = 9;
@@ -121,10 +153,11 @@
   const int filter_count = 1;
   const int dilation_factor = 3;
   DepthwiseConvolutionOpModel m(
+      GetRegistration(),
       {TensorType_FLOAT32,
        {image_batch_count, image_height, image_width, depth}},
       {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
-      {TensorType_FLOAT32, {}}, dilation_factor);
+      {TensorType_FLOAT32, {}}, Padding_VALID, dilation_factor);
 
   // The image matrix is:
   // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
@@ -164,6 +197,41 @@
   EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5}));
 }
 
+TEST_P(DepthwiseConvolutionOpTest, SimpleDilatedTestPaddingSame) {
+  const int depth = 1;
+  const int image_width = 3;
+  const int image_height = 3;
+  const int image_batch_count = 1;
+  const int filter_size = 2;
+  const int filter_count = 1;
+  const int dilation_factor = 2;
+  DepthwiseConvolutionOpModel m(
+      GetRegistration(),
+      {TensorType_FLOAT32,
+       {image_batch_count, image_height, image_width, depth}},
+      {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
+      {TensorType_FLOAT32, {}}, Padding_SAME, dilation_factor);
+
+  // The image matrix is:
+  // | 1 | 1 | 1 |
+  // | 1 | 1 | 1 |
+  // | 1 | 1 | 1 |
+  m.SetInput({1, 1, 1, 1, 1, 1, 1, 1, 1});
+  // The filter matrix is:
+  // | 1 | 2 |
+  // | 3 | 4 |
+  m.SetFilter({1, 2, 3, 4});
+  // No bias for this test.
+  m.SetBias({0});
+  m.Invoke();
+
+  // Output:
+  // | 4 | 7 | 3 |
+  // | 6 |10 | 4 |
+  // | 2 | 3 | 1 |
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({4, 7, 3, 6, 10, 4, 2, 3, 1}));
+}
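The expected output in the test above follows from the dilated convolution indexing `in = out - pad + filter_index * dilation`. A standalone sketch that recomputes it, assuming SAME padding of one on each side (consistent with stride 1 and the effective 3x3 dilated filter):

```
#include <cassert>
#include <vector>

int main() {
  const int size = 3, filter_size = 2, dilation = 2, pad = 1;
  const float input[3][3] = {{1, 1, 1}, {1, 1, 1}, {1, 1, 1}};
  const float filter[2][2] = {{1, 2}, {3, 4}};
  std::vector<float> output;
  for (int out_y = 0; out_y < size; ++out_y) {
    for (int out_x = 0; out_x < size; ++out_x) {
      float acc = 0;
      for (int fy = 0; fy < filter_size; ++fy) {
        for (int fx = 0; fx < filter_size; ++fx) {
          // Out-of-bounds positions fall in the zero padding.
          const int in_y = out_y - pad + fy * dilation;
          const int in_x = out_x - pad + fx * dilation;
          if (in_y >= 0 && in_y < size && in_x >= 0 && in_x < size) {
            acc += input[in_y][in_x] * filter[fy][fx];
          }
        }
      }
      output.push_back(acc);
    }
  }
  // Matches the values asserted by SimpleDilatedTestPaddingSame.
  const std::vector<float> expected = {4, 7, 3, 6, 10, 4, 2, 3, 1};
  assert(output == expected);
  return 0;
}
```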
+
 class QuantizedDepthwiseConvolutionOpModel
     : public BaseDepthwiseConvolutionOpModel {
  public:
@@ -188,13 +256,20 @@
   }
 };
 
+class QuantizedDepthwiseConvolutionOpTest : public SingleOpTest {
+ protected:
+  const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
+    return *kKernelMap;
+  }
+};
+
 // In this test we set the input and output scales so that the results match
 // exactly the 'non-quantized' version.
-TEST(QuantizedDepthwiseConvolutionOpTest, SimpleTestQuantized) {
+TEST_P(QuantizedDepthwiseConvolutionOpTest, SimpleTestQuantized) {
   QuantizedDepthwiseConvolutionOpModel m(
-      {TensorType_UINT8, {1, 3, 2, 2}, -63.5, 64},
+      GetRegistration(), {TensorType_UINT8, {1, 3, 2, 2}, -63.5, 64},
       {TensorType_UINT8, {1, 2, 2, 4}, -63.5, 64},
-      {TensorType_UINT8, {}, -127, 128});
+      {TensorType_UINT8, {}, -127, 128}, Padding_VALID);
 
   m.SetInput({
       1, 2, 7, 8,    // column 1
@@ -224,15 +299,16 @@
                              }));
 }
 
-TEST(QuantizedDepthwiseConvolutionOpTest,
-     SimpleTestQuantizedFilterMultiplierGreaterThan1) {
+TEST_P(QuantizedDepthwiseConvolutionOpTest,
+       SimpleTestQuantizedFilterMultiplierGreaterThan1) {
   QuantizedDepthwiseConvolutionOpModel quant_op(
-      {TensorType_UINT8, {1, 3, 2, 2}, -63.5, 64},
+      GetRegistration(), {TensorType_UINT8, {1, 3, 2, 2}, -63.5, 64},
       {TensorType_UINT8, {1, 2, 2, 4}, -128.5, 128},
-      {TensorType_UINT8, {}, -127, 128});
-  DepthwiseConvolutionOpModel float_op({TensorType_FLOAT32, {1, 3, 2, 2}},
+      {TensorType_UINT8, {}, -127, 128}, Padding_VALID);
+  DepthwiseConvolutionOpModel float_op(GetRegistration(),
+                                       {TensorType_FLOAT32, {1, 3, 2, 2}},
                                        {TensorType_FLOAT32, {1, 2, 2, 4}},
-                                       {TensorType_FLOAT32, {}});
+                                       {TensorType_FLOAT32, {}}, Padding_VALID);
 
   std::initializer_list<float> input = {
       1, 2, 7,  8,   // column 1
@@ -261,7 +337,7 @@
               ElementsAreArray(ArrayFloatNear(float_op.GetOutput(), 1)));
 }
 
-TEST(QuantizedDepthwiseConvolutionOpTest, SimpleDilatedTest) {
+TEST_P(QuantizedDepthwiseConvolutionOpTest, SimpleDilatedTestPaddingValid) {
   const int depth = 1;
   const int image_width = 9;
   const int image_height = 9;
@@ -270,6 +346,7 @@
   const int filter_count = 1;
   const int dilation_factor = 3;
   QuantizedDepthwiseConvolutionOpModel m(
+      GetRegistration(),
       {TensorType_UINT8,
        {image_batch_count, image_height, image_width, depth},
        0,
@@ -278,7 +355,7 @@
        {depth, filter_size, filter_size, filter_count},
        0,
        255},
-      {TensorType_UINT8, {}, 0, 255}, dilation_factor);
+      {TensorType_UINT8, {}, 0, 255}, Padding_VALID, dilation_factor);
 
   // The image matrix is:
   // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
@@ -319,6 +396,55 @@
               ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5}));
 }
 
+TEST_P(QuantizedDepthwiseConvolutionOpTest, SimpleDilatedTestPaddingSame) {
+  const int depth = 1;
+  const int image_width = 3;
+  const int image_height = 3;
+  const int image_batch_count = 1;
+  const int filter_size = 2;
+  const int filter_count = 1;
+  const int dilation_factor = 2;
+  QuantizedDepthwiseConvolutionOpModel m(
+      GetRegistration(),
+      {TensorType_UINT8,
+       {image_batch_count, image_height, image_width, depth},
+       0,
+       255},
+      {TensorType_UINT8,
+       {depth, filter_size, filter_size, filter_count},
+       0,
+       255},
+      {TensorType_UINT8, {}, 0, 255}, Padding_SAME, dilation_factor);
+
+  // The image matrix is:
+  // | 1 | 1 | 1 |
+  // | 1 | 1 | 1 |
+  // | 1 | 1 | 1 |
+  m.SetInput({1, 1, 1, 1, 1, 1, 1, 1, 1});
+  // The filter matrix is:
+  // | 1 | 2 |
+  // | 3 | 4 |
+  m.SetFilter({1, 2, 3, 4});
+  // No bias for this test.
+  m.SetBias({0});
+  m.Invoke();
+
+  // Output:
+  // | 4 | 7 | 3 |
+  // | 6 |10 | 4 |
+  // | 2 | 3 | 1 |
+  EXPECT_THAT(m.GetDequantizedOutput(),
+              ElementsAreArray({4, 7, 3, 6, 10, 4, 2, 3, 1}));
+}
+
+INSTANTIATE_TEST_CASE_P(
+    DepthwiseConvolutionOpTest, DepthwiseConvolutionOpTest,
+    ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
+
+INSTANTIATE_TEST_CASE_P(
+    QuantizedDepthwiseConvolutionOpTest, QuantizedDepthwiseConvolutionOpTest,
+    ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
+
 }  // namespace
 }  // namespace tflite
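The expected values in SimpleDilatedTestPaddingSame can be checked by hand: with a 3x3 all-ones input, the 2x2 filter {1, 2, 3, 4}, dilation factor 2 and SAME padding (one pixel on each side), each output pixel is just the sum of the filter taps that land inside the image. A minimal stand-alone reference check of those numbers (an illustration only, independent of the TFLite test harness above):

#include <cassert>
#include <vector>

int main() {
  const int kInputSize = 3, kFilterSize = 2, kDilation = 2, kPad = 1;
  const float kFilter[2][2] = {{1.f, 2.f}, {3.f, 4.f}};
  std::vector<float> output;
  for (int out_y = 0; out_y < kInputSize; ++out_y) {
    for (int out_x = 0; out_x < kInputSize; ++out_x) {
      float acc = 0.f;
      for (int filter_y = 0; filter_y < kFilterSize; ++filter_y) {
        for (int filter_x = 0; filter_x < kFilterSize; ++filter_x) {
          const int in_y = out_y - kPad + kDilation * filter_y;
          const int in_x = out_x - kPad + kDilation * filter_x;
          // The input is all ones, so a tap inside the image contributes its
          // filter weight and a tap that falls in the padding contributes 0.
          if (in_y >= 0 && in_y < kInputSize && in_x >= 0 && in_x < kInputSize) {
            acc += kFilter[filter_y][filter_x];
          }
        }
      }
      output.push_back(acc);
    }
  }
  const std::vector<float> expected = {4, 7, 3, 6, 10, 4, 2, 3, 1};
  assert(output == expected);
  return 0;
}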
 
diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD
index a6fd4ac..195474e 100644
--- a/tensorflow/contrib/lite/kernels/internal/BUILD
+++ b/tensorflow/contrib/lite/kernels/internal/BUILD
@@ -43,6 +43,7 @@
         "compatibility.h",
         "types.h",
     ],
+    deps = ["@com_google_absl//absl/base:core_headers"],
 )
 
 config_setting(
@@ -458,7 +459,7 @@
     ],
     copts = NEON_FLAGS_IF_APPLICABLE,
     deps = [
-        "//tensorflow/contrib/lite/kernels:activation_functor",
+        "@com_google_absl//absl/base:core_headers",
         "//tensorflow/contrib/lite/c:c_api_internal",
         "@arm_neon_2_x86_sse",
         "@gemmlowp",
diff --git a/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc
index 844ee6a..7600b26 100644
--- a/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc
@@ -17,6 +17,7 @@
 #include <vector>
 
 #include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
 #include "tensorflow/contrib/lite/kernels/internal/test_util.h"
 #include "tensorflow/contrib/lite/kernels/internal/types.h"
 
@@ -28,23 +29,29 @@
 namespace {
 
 // Runs the DepthwiseConv and compares against the reference implementation.
-template <FusedActivationFunctionType Ac>
 void TestOneDepthwiseConv(const float* input_data, const Dims<4>& input_dims,
                           const float* filter_data, const Dims<4>& filter_dims,
                           const float* bias_data, const Dims<4>& bias_dims,
-                          int stride, int pad_width, int pad_height,
-                          int depth_multiplier, const Dims<4>& output_dims) {
+                          int stride, int dilation_width_factor,
+                          int dilation_height_factor, int pad_width,
+                          int pad_height, int depth_multiplier,
+                          float output_activation_min,
+                          float output_activation_max,
+                          const Dims<4>& output_dims) {
   const int output_buffer_size = RequiredBufferSizeForDims(output_dims);
   std::vector<float> output_data(output_buffer_size);
   std::vector<float> reference_output_data(output_buffer_size);
-  reference_ops::DepthwiseConv<Ac>(input_data, input_dims, filter_data,
-                                   filter_dims, bias_data, bias_dims, stride,
-                                   pad_width, pad_height, depth_multiplier,
-                                   reference_output_data.data(), output_dims);
-  optimized_ops::DepthwiseConv<Ac>(input_data, input_dims, filter_data,
-                                   filter_dims, bias_data, bias_dims, stride,
-                                   pad_width, pad_height, depth_multiplier,
-                                   output_data.data(), output_dims);
+  reference_ops::DepthwiseConv(
+      input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
+      stride, stride, dilation_width_factor, dilation_height_factor, pad_width,
+      pad_height, depth_multiplier, output_activation_min,
+      output_activation_max, reference_output_data.data(), output_dims);
+  optimized_ops::DepthwiseConv(
+      input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
+      stride, stride, dilation_width_factor, dilation_height_factor, pad_width,
+      pad_height, depth_multiplier, output_activation_min,
+      output_activation_max, output_data.data(), output_dims);
+
   double sum_abs_diff = 0;
   float max_abs_val = 0;
   for (int i = 0; i < output_buffer_size; i++) {
@@ -59,27 +66,6 @@
   }
 }
 
-void TestOneDepthwiseConv(FusedActivationFunctionType Ac,
-                          const float* input_data, const Dims<4>& input_dims,
-                          const float* filter_data, const Dims<4>& filter_dims,
-                          const float* bias_data, const Dims<4>& bias_dims,
-                          int stride, int pad_width, int pad_height,
-                          int depth_multiplier, const Dims<4>& output_dims) {
-#define TOCO_HANDLE_CASE(AC_TYPE)                                            \
-  if (AC_TYPE == Ac) {                                                       \
-    TestOneDepthwiseConv<AC_TYPE>(input_data, input_dims, filter_data,       \
-                                  filter_dims, bias_data, bias_dims, stride, \
-                                  pad_width, pad_height, depth_multiplier,   \
-                                  output_dims);                              \
-    return;                                                                  \
-  }
-  TOCO_HANDLE_CASE(FusedActivationFunctionType::kNone)
-  TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu)
-  TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu1)
-  TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu6)
-#undef TOCO_HANDLE_CASE
-}
-
 // This function picks some random DepthwiseConv params, which may or may not
 // be legal. If they're not legal, it returns false. If they're legal,
 // it runs the DepthwiseConv test and returns true. This allows the caller
@@ -99,6 +85,16 @@
   const int depth_multiplier = ExponentialRandomPositiveInt(0.8f, 6, 50);
   const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8);
   const int output_depth = input_depth * depth_multiplier;
+  const int dilation_width_factor = RandomElement(std::vector<int>({1, 2, 4}));
+  const int dilation_height_factor = RandomElement(std::vector<int>({1, 2, 4}));
+  float output_activation_min, output_activation_max;
+  FusedActivationFunctionType ac =
+      RandomElement(std::vector<FusedActivationFunctionType>(
+          {FusedActivationFunctionType::kNone,
+           FusedActivationFunctionType::kRelu,
+           FusedActivationFunctionType::kRelu1,
+           FusedActivationFunctionType::kRelu6}));
+  GetActivationMinMax(ac, &output_activation_min, &output_activation_max);
   // The optimized DepthwiseConv implementation currently uses a fixed-size
   // accumulator buffer on the stack, with that size. This currently means
   // that it does not support larger output depths. It CHECK's for it,
@@ -109,10 +105,6 @@
   if (output_depth > kMaxSupportedOutputDepth) {
     return false;
   }
-  const auto ac = RandomElement(std::vector<FusedActivationFunctionType>(
-      {FusedActivationFunctionType::kNone, FusedActivationFunctionType::kRelu,
-       FusedActivationFunctionType::kRelu6,
-       FusedActivationFunctionType::kRelu1}));
   Dims<4> input_dims_inference =
       MakeDimsForInference(input_depth, input_width, input_height, batch);
   Dims<4> output_dims_inference;
@@ -120,7 +112,8 @@
   const auto padding_type =
       UniformRandomInt(0, 1) ? PaddingType::kSame : PaddingType::kValid;
   if (!ComputeConvSizes(input_dims_inference, output_depth, filter_width,
-                        filter_height, stride, padding_type,
+                        filter_height, stride, dilation_width_factor,
+                        dilation_height_factor, padding_type,
                         &output_dims_inference, &pad_width, &pad_height)) {
     return false;
   }
@@ -140,10 +133,12 @@
   FillRandom(&input_data, -input_amplitude, input_amplitude);
   FillRandom(&filter_data, -filter_amplitude, filter_amplitude);
   FillRandom(&bias_data, -bias_amplitude, bias_amplitude);
-  TestOneDepthwiseConv(ac, input_data.data(), input_dims_inference,
+  TestOneDepthwiseConv(input_data.data(), input_dims_inference,
                        filter_data.data(), filter_dims_inference,
-                       bias_data.data(), bias_dims_inference, stride, pad_width,
-                       pad_height, depth_multiplier, output_dims_inference);
+                       bias_data.data(), bias_dims_inference, stride,
+                       dilation_width_factor, dilation_height_factor, pad_width,
+                       pad_height, depth_multiplier, output_activation_min,
+                       output_activation_max, output_dims_inference);
   return true;
 }
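The float test above no longer templatizes TestOneDepthwiseConv on the fused activation; it draws an activation at random and converts it to explicit clamp bounds via GetActivationMinMax, which are then passed as output_activation_min/max. For reference, the clamps each activation implies look roughly like the sketch below (a local illustration with a stand-in enum; the real mapping lives in the TFLite internal headers and may express the unbounded cases with float limits rather than infinities):

#include <limits>

// Stand-in for FusedActivationFunctionType, for illustration only.
enum class FusedActivation { kNone, kRelu, kRelu1, kRelu6 };

// Sketch of the (min, max) clamp each fused activation implies.
void ActivationToMinMax(FusedActivation ac, float* output_min,
                        float* output_max) {
  const float kInf = std::numeric_limits<float>::infinity();
  switch (ac) {
    case FusedActivation::kNone:
      *output_min = -kInf;
      *output_max = kInf;
      break;
    case FusedActivation::kRelu:
      *output_min = 0.f;
      *output_max = kInf;
      break;
    case FusedActivation::kRelu1:
      *output_min = -1.f;
      *output_max = 1.f;
      break;
    case FusedActivation::kRelu6:
      *output_min = 0.f;
      *output_max = 6.f;
      break;
  }
}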
 
diff --git a/tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc
index 2c0fc84..312d048 100644
--- a/tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc
@@ -199,6 +199,7 @@
 bool TryTestDepthwiseConv(int batch, int input_depth, int input_width,
                           int input_height, int filter_width, int filter_height,
                           int depth_multiplier, int stride,
+                          int dilation_width_factor, int dilation_height_factor,
                           PaddingType padding_type) {
   const int output_depth = input_depth * depth_multiplier;
   // The optimized DepthwiseConv implementation currently uses a fixed-size
@@ -231,7 +232,8 @@
   Dims<4> output_dims_inference;
   int pad_width, pad_height;
   if (!ComputeConvSizes(input_dims_inference, output_depth, filter_width,
-                        filter_height, stride, padding_type,
+                        filter_height, stride, dilation_width_factor,
+                        dilation_height_factor, padding_type,
                         &output_dims_inference, &pad_width, &pad_height)) {
     return false;
   }
@@ -274,12 +276,15 @@
   const int filter_height = ExponentialRandomPositiveInt(0.9f, 4, 10);
   const int depth_multiplier = ExponentialRandomPositiveInt(0.8f, 6, 50);
   const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8);
+  const int dilation_width_factor = RandomElement(std::vector<int>({1, 2, 4}));
+  const int dilation_height_factor = RandomElement(std::vector<int>({1, 2, 4}));
   const auto padding_type =
       UniformRandomInt(0, 1) ? PaddingType::kSame : PaddingType::kValid;
 
   return TryTestDepthwiseConv(batch, input_depth, input_width, input_height,
                               filter_width, filter_height, depth_multiplier,
-                              stride, padding_type);
+                              stride, dilation_width_factor,
+                              dilation_height_factor, padding_type);
 }
 
 // Tests parameters for the 3x3 filter kernel.
@@ -292,6 +297,9 @@
   const int filter_height = 3;
   const int depth_multiplier = 1;
   const int stride = UniformRandomInt(1, 2);
+  // The 3x3 filter kernel does not support dilation, so keep both factors at 1.
+  const int dilation_width_factor = 1;
+  const int dilation_height_factor = 1;
   // Although the kernel supports only kValid padding, we test that kSame
   // is using the correct code path.
   const auto padding_type =
@@ -299,7 +307,8 @@
 
   return TryTestDepthwiseConv(batch, input_depth, input_width, input_height,
                               filter_width, filter_height, depth_multiplier,
-                              stride, padding_type);
+                              stride, dilation_width_factor,
+                              dilation_height_factor, padding_type);
 }
 
 void TestOneDepthwiseConv() {
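As in the float test, the quantized fuzz test now draws dilation factors from {1, 2, 4} and threads them through ComputeConvSizes; a draw whose dilated filter no longer fits the padded input simply makes TryTestDepthwiseConv return false. The quantity that matters is the effective filter extent (a small illustration, not a helper from this file):

int EffectiveFilterSize(int filter_size, int dilation_factor) {
  // A dilated filter with k taps touches (k - 1) * dilation + 1 input rows or
  // columns, e.g. 3 taps with dilation 4 span 9 input pixels.
  return (filter_size - 1) * dilation_factor + 1;
}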
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
index 70810ca..114575a 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
@@ -761,7 +761,8 @@
 // Accumulates the effect of one row of the filter, on a segment of one row
 // of the output, accessing the corresponding one row of the input.
 template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
-void FloatDepthwiseConvAccumRow(int stride, int input_depth, int input_width,
+void FloatDepthwiseConvAccumRow(int stride, int dilation_factor,
+                                int input_depth, int input_width,
                                 const float* input_data, int pad_width,
                                 int depth_multiplier, int filter_width,
                                 const float* filter_data,
@@ -835,10 +836,10 @@
 
 // generic fallback of FloatDepthwiseConvAccumRow, portable, non-templatized.
 inline void FloatDepthwiseConvAccumRowGeneric(
-    int stride, int input_depth, int input_width, const float* input_data,
-    int pad_width, int depth_multiplier, int filter_width,
-    const float* filter_data, int out_x_buffer_start, int out_x_buffer_end,
-    int output_depth, float* acc_buffer) {
+    int stride, int dilation_factor, int input_depth, int input_width,
+    const float* input_data, int pad_width, int depth_multiplier,
+    int filter_width, const float* filter_data, int out_x_buffer_start,
+    int out_x_buffer_end, int output_depth, float* acc_buffer) {
   gemmlowp::ScopedProfilingLabel label("DepthwiseConvAccumRowGeneric (slow)");
 #ifdef TFLITE_PREVENT_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
 #ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
@@ -860,6 +861,7 @@
       << "* stride = " << stride << "\n"
       << "* input_depth = " << input_depth << "\n"
       << "* depth_multiplier = " << depth_multiplier << "\n"
+      << "* dilation_factor = " << dilation_factor << "\n"
       << "*\n"
       << "* Please do not hesitate to contact benoitjacob@ with this\n"
       << "* information.\n"
@@ -869,14 +871,17 @@
   const float* filter_base_ptr = filter_data;
   for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
     const int out_x_loop_start = std::max(
-        out_x_buffer_start, (pad_width - filter_x + stride - 1) / stride);
-    const int out_x_loop_end =
-        std::min(out_x_buffer_end,
-                 (pad_width + input_width - filter_x + stride - 1) / stride);
+        out_x_buffer_start,
+        (pad_width - dilation_factor * filter_x + stride - 1) / stride);
+    const int out_x_loop_end = std::min(
+        out_x_buffer_end,
+        (pad_width + input_width - dilation_factor * filter_x + stride - 1) /
+            stride);
 
     float* acc_buffer_ptr =
         acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth;
-    const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x;
+    const int in_x_origin =
+        (out_x_loop_start * stride) - pad_width + dilation_factor * filter_x;
     const float* input_ptr = input_data + in_x_origin * input_depth;
     const int input_ptr_increment = (stride - 1) * input_depth;
     for (int out_x = out_x_loop_start; out_x < out_x_loop_end; out_x++) {
@@ -907,25 +912,37 @@
   }
 }
 
-inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
-                          const float* filter_data, const Dims<4>& filter_dims,
-                          const float* bias_data, const Dims<4>& bias_dims,
-                          int stride_width, int stride_height, int pad_width,
-                          int pad_height, int depth_multiplier,
-                          float output_activation_min,
-                          float output_activation_max, float* output_data,
-                          const Dims<4>& output_dims) {
+inline void DepthwiseConv(
+    const DepthwiseParams& params, const RuntimeShape& input_shape,
+    const float* input_data, const RuntimeShape& filter_shape,
+    const float* filter_data, const RuntimeShape& bias_shape,
+    const float* bias_data, const RuntimeShape& output_shape,
+    float* output_data) {
   gemmlowp::ScopedProfilingLabel label("DepthwiseConv");
-  const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
-  const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
-  const int input_height = ArraySize(input_dims, 2);
-  const int input_width = ArraySize(input_dims, 1);
-  const int input_depth = ArraySize(input_dims, 0);
-  const int filter_height = ArraySize(filter_dims, 2);
-  const int filter_width = ArraySize(filter_dims, 1);
-  const int output_height = ArraySize(output_dims, 2);
-  const int output_width = ArraySize(output_dims, 1);
-  TFLITE_DCHECK(output_depth == input_depth * depth_multiplier);
+  const int stride_width = params.stride_width;
+  const int stride_height = params.stride_height;
+  const int pad_width = params.padding_values.width;
+  const int pad_height = params.padding_values.height;
+  const int depth_multiplier = params.depth_multiplier;
+  const float output_activation_min = params.float_activation_min;
+  const float output_activation_max = params.float_activation_max;
+  const int dilation_width_factor = params.dilation_width_factor;
+  const int dilation_height_factor = params.dilation_height_factor;
+  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
+  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+  const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
+  const int input_height = input_shape.Dims(1);
+  const int input_width = input_shape.Dims(2);
+  const int input_depth = input_shape.Dims(3);
+  const int filter_height = filter_shape.Dims(1);
+  const int filter_width = filter_shape.Dims(2);
+  const int output_height = output_shape.Dims(1);
+  const int output_width = output_shape.Dims(2);
+  TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
+  TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
 
   static const int kAccBufferMaxSize = 2048;
   float acc_buffer[kAccBufferMaxSize];
@@ -946,7 +963,8 @@
                                         FIXED_DEPTH_MULTIPLIER)           \
   if (!row_accum_func && (stride_width == 1 || ALLOW_STRIDED) &&          \
       (input_depth == FIXED_INPUT_DEPTH || FIXED_INPUT_DEPTH == 0) &&     \
-      depth_multiplier == FIXED_DEPTH_MULTIPLIER) {                       \
+      depth_multiplier == FIXED_DEPTH_MULTIPLIER &&                       \
+      dilation_height_factor == 1 && dilation_width_factor == 1) {        \
     row_accum_func =                                                      \
         FloatDepthwiseConvAccumRow<ALLOW_STRIDED, FIXED_INPUT_DEPTH,      \
                                    FIXED_DEPTH_MULTIPLIER>;               \
@@ -990,14 +1008,22 @@
     row_accum_func = FloatDepthwiseConvAccumRowGeneric;
   }
 
+  const int input_height_stride = input_shape.Dims(3) * input_shape.Dims(2);
+  const int input_batch_stride = input_height_stride * input_shape.Dims(1);
+  const int filter_height_stride = filter_shape.Dims(3) * filter_shape.Dims(2);
+
   // Now that we have determined row_accum_func, we can start work.
   float* output_ptr = output_data;
   for (int b = 0; b < batches; ++b) {
     for (int out_y = 0; out_y < output_height; ++out_y) {
       const int in_y_origin = (out_y * stride_height) - pad_height;
-      const int filter_y_start = std::max(0, -in_y_origin);
+      const int filter_y_start =
+          std::max(0, (-in_y_origin + dilation_height_factor - 1) /
+                          dilation_height_factor);
       const int filter_y_end =
-          std::min(filter_height, input_height - in_y_origin);
+          std::min(filter_height,
+                   (input_height - in_y_origin + dilation_height_factor - 1) /
+                       dilation_height_factor);
       for (int out_x_buffer_start = 0; out_x_buffer_start < output_width;
            out_x_buffer_start += kOutputPixelsInAccBuffer) {
         const int out_x_buffer_end = std::min(
@@ -1013,14 +1039,13 @@
         // Accumulation loop. Most of the time should be spent in here.
         for (int filter_y = filter_y_start; filter_y < filter_y_end;
              ++filter_y) {
-          const int in_y = in_y_origin + filter_y;
-          row_accum_func(stride_width, input_depth, input_width,
-                         input_data + in_y * input_dims.strides[2] +
-                             b * input_dims.strides[3],
-                         pad_width, depth_multiplier, filter_width,
-                         filter_data + filter_y * filter_dims.strides[2],
-                         out_x_buffer_start, out_x_buffer_end, output_depth,
-                         acc_buffer);
+          const int in_y = in_y_origin + dilation_height_factor * filter_y;
+          row_accum_func(
+              stride_width, dilation_width_factor, input_depth, input_width,
+              input_data + in_y * input_height_stride + b * input_batch_stride,
+              pad_width, depth_multiplier, filter_width,
+              filter_data + filter_y * filter_height_stride, out_x_buffer_start,
+              out_x_buffer_end, output_depth, acc_buffer);
         }
         // Finished accumulating. Now store to destination.
         const int num_output_values = output_depth * num_output_pixels;
@@ -1067,6 +1092,8 @@
   }
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
 inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
                           const float* filter_data, const Dims<4>& filter_dims,
                           const float* bias_data, const Dims<4>& bias_dims,
@@ -1076,17 +1103,40 @@
                           float output_activation_min,
                           float output_activation_max, float* output_data,
                           const Dims<4>& output_dims) {
-  // TODO(suharshs): Optimized implementation of dilation depthwise conv need to
-  // be implemented.
-  TFLITE_DCHECK(dilation_width_factor == 1);
-  TFLITE_DCHECK(dilation_height_factor == 1);
+  tflite::DepthwiseParams op_params;
+  // The padding type is unused here; only the explicit padding values matter.
+  op_params.padding_type = PaddingType::kSame;
+  op_params.padding_values.width = pad_width;
+  op_params.padding_values.height = pad_height;
+  op_params.stride_width = stride_width;
+  op_params.stride_height = stride_height;
+  op_params.dilation_width_factor = dilation_width_factor;
+  op_params.dilation_height_factor = dilation_height_factor;
+  op_params.depth_multiplier = depth_multiplier;
+  op_params.float_activation_min = output_activation_min;
+  op_params.float_activation_max = output_activation_max;
 
-  DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
-                bias_dims, stride_width, stride_height, pad_width, pad_height,
-                depth_multiplier, output_activation_min, output_activation_max,
-                output_data, output_dims);
+  DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
+                DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+                bias_data, DimsToShape(output_dims), output_data);
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+                          const float* filter_data, const Dims<4>& filter_dims,
+                          const float* bias_data, const Dims<4>& bias_dims,
+                          int stride_width, int stride_height, int pad_width,
+                          int pad_height, int depth_multiplier,
+                          float output_activation_min,
+                          float output_activation_max, float* output_data,
+                          const Dims<4>& output_dims) {
+  DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
+                bias_dims, stride_width, stride_height, 1, 1, pad_width,
+                pad_height, depth_multiplier, output_activation_min,
+                output_activation_max, output_data, output_dims);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
 // legacy, for compatibility with old checked-in code
 template <FusedActivationFunctionType Ac>
 void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
@@ -1103,6 +1153,7 @@
                 output_data, output_dims);
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
 // legacy, for compatibility with old checked-in code
 template <FusedActivationFunctionType Ac>
 void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
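The new filter_y_start / filter_y_end expressions in the optimized float kernel are the dilation-aware generalization of the old max(0, -in_y_origin) / min(filter_height, input_height - in_y_origin) bounds: filter row filter_y reads input row in_y_origin + dilation_height_factor * filter_y, and the ceiling divisions select exactly the rows that fall inside the image. A small self-contained check of that equivalence (a sketch, not part of the library):

#include <algorithm>
#include <cassert>

int main() {
  const int kFilterHeight = 5;
  for (int input_height = 1; input_height <= 8; ++input_height) {
    for (int dilation = 1; dilation <= 3; ++dilation) {
      for (int in_y_origin = -4; in_y_origin <= 8; ++in_y_origin) {
        const int filter_y_start =
            std::max(0, (-in_y_origin + dilation - 1) / dilation);
        const int filter_y_end = std::min(
            kFilterHeight,
            (input_height - in_y_origin + dilation - 1) / dilation);
        for (int filter_y = 0; filter_y < kFilterHeight; ++filter_y) {
          const int in_y = in_y_origin + dilation * filter_y;
          // The loop bounds must admit exactly the filter rows whose input
          // row lands inside the image.
          const bool in_image = in_y >= 0 && in_y < input_height;
          const bool in_loop =
              filter_y >= filter_y_start && filter_y < filter_y_end;
          assert(in_image == in_loop);
        }
      }
    }
  }
  return 0;
}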
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
index f707279..f892b8f 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
@@ -24,6 +24,9 @@
 namespace tflite {
 namespace optimized_ops {
 
+// TODO(b/80418076): Move to legacy ops file, along with invocations.
+static constexpr int kDepthwiseReverseShift = -1;
+
 // Implementation of quantized DepthwiseConv
 
 template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
@@ -1466,11 +1469,14 @@
 // Accumulates the effect of one row of the filter, on a segment of one row
 // of the output, accessing the corresponding one row of the input.
 template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
-void QuantizedDepthwiseConvAccumRow(
-    int stride, int input_depth, int input_width, const uint8* input_data,
-    int16 input_offset, int pad_width, int depth_multiplier, int filter_width,
-    const uint8* filter_data, int16 filter_offset, int out_x_buffer_start,
-    int out_x_buffer_end, int output_depth, int32* acc_buffer) {
+void QuantizedDepthwiseConvAccumRow(int stride, int dilation_factor,
+                                    int input_depth, int input_width,
+                                    const uint8* input_data, int16 input_offset,
+                                    int pad_width, int depth_multiplier,
+                                    int filter_width, const uint8* filter_data,
+                                    int16 filter_offset, int out_x_buffer_start,
+                                    int out_x_buffer_end, int output_depth,
+                                    int32* acc_buffer) {
 #ifdef GEMMLOWP_PROFILING
   gemmlowp::ScopedProfilingLabel label(__PRETTY_FUNCTION__);
 #endif
@@ -1537,10 +1543,11 @@
 
 // generic fallback of DepthwiseConvAccumRow, portable, non-templatized.
 inline void QuantizedDepthwiseConvAccumRowGeneric(
-    int stride, int input_depth, int input_width, const uint8* input_data,
-    int16 input_offset, int pad_width, int depth_multiplier, int filter_width,
-    const uint8* filter_data, int16 filter_offset, int out_x_buffer_start,
-    int out_x_buffer_end, int output_depth, int32* acc_buffer) {
+    int stride, int dilation_factor, int input_depth, int input_width,
+    const uint8* input_data, int16 input_offset, int pad_width,
+    int depth_multiplier, int filter_width, const uint8* filter_data,
+    int16 filter_offset, int out_x_buffer_start, int out_x_buffer_end,
+    int output_depth, int32* acc_buffer) {
   gemmlowp::ScopedProfilingLabel label("DepthwiseConvAccumRowGeneric (slow)");
 #ifdef TFLITE_PREVENT_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
 #ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
@@ -1562,6 +1569,7 @@
       << "* stride = " << stride << "\n"
       << "* input_depth = " << input_depth << "\n"
       << "* depth_multiplier = " << depth_multiplier << "\n"
+      << "* dilation_factor = " << dilation_factor << "\n"
       << "*\n"
       << "* Please do not hesitate to contact benoitjacob@ with this\n"
       << "* information.\n"
@@ -1571,14 +1579,17 @@
   const uint8* filter_base_ptr = filter_data;
   for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
     const int out_x_loop_start = std::max(
-        out_x_buffer_start, (pad_width - filter_x + stride - 1) / stride);
-    const int out_x_loop_end =
-        std::min(out_x_buffer_end,
-                 (pad_width + input_width - filter_x + stride - 1) / stride);
+        out_x_buffer_start,
+        (pad_width - dilation_factor * filter_x + stride - 1) / stride);
+    const int out_x_loop_end = std::min(
+        out_x_buffer_end,
+        (pad_width + input_width - dilation_factor * filter_x + stride - 1) /
+            stride);
 
     int32* acc_buffer_ptr =
         acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth;
-    const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x;
+    const int in_x_origin =
+        (out_x_loop_start * stride) - pad_width + dilation_factor * filter_x;
     const uint8* input_ptr = input_data + in_x_origin * input_depth;
     const int input_ptr_increment = (stride - 1) * input_depth;
     for (int out_x = out_x_loop_start; out_x < out_x_loop_end; out_x++) {
@@ -1669,33 +1680,46 @@
   }
 }
 
-inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
-                          int32 input_offset, const uint8* filter_data,
-                          const Dims<4>& filter_dims, int32 filter_offset,
-                          const int32* bias_data, const Dims<4>& bias_dims,
-                          int stride_width, int stride_height, int pad_width,
-                          int pad_height, int depth_multiplier,
-                          int32 output_offset, int32 output_multiplier,
-                          int output_shift, int32 output_activation_min,
-                          int32 output_activation_max, uint8* output_data,
-                          const Dims<4>& output_dims) {
+inline void DepthwiseConv(
+    const DepthwiseParams& params, const RuntimeShape& input_shape,
+    const uint8* input_data, const RuntimeShape& filter_shape,
+    const uint8* filter_data, const RuntimeShape& bias_shape,
+    const int32* bias_data, const RuntimeShape& output_shape,
+    uint8* output_data) {
   gemmlowp::ScopedProfilingLabel label("DepthwiseConv/8bit");
+  const int stride_width = params.stride_width;
+  const int stride_height = params.stride_height;
+  const int pad_width = params.padding_values.width;
+  const int pad_height = params.padding_values.height;
+  const int depth_multiplier = params.depth_multiplier;
+  const int32 output_activation_min = params.quantized_activation_min;
+  const int32 output_activation_max = params.quantized_activation_max;
+  const int32 input_offset = params.input_offset;
+  const int32 filter_offset = params.weights_offset;
+  const int32 output_offset = params.output_offset;
+  const int32 output_multiplier = params.output_multiplier;
+  const int output_shift = params.output_shift;
+  const int dilation_width_factor = params.dilation_width_factor;
+  const int dilation_height_factor = params.dilation_height_factor;
+  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
-
-  const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
-  const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
-  const int input_height = ArraySize(input_dims, 2);
-  const int input_width = ArraySize(input_dims, 1);
-  const int input_depth = ArraySize(input_dims, 0);
-  const int filter_height = ArraySize(filter_dims, 2);
-  const int filter_width = ArraySize(filter_dims, 1);
-  const int output_height = ArraySize(output_dims, 2);
-  const int output_width = ArraySize(output_dims, 1);
+  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+  const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
+  const int input_height = input_shape.Dims(1);
+  const int input_width = input_shape.Dims(2);
+  const int input_depth = input_shape.Dims(3);
+  const int filter_height = filter_shape.Dims(1);
+  const int filter_width = filter_shape.Dims(2);
+  const int output_height = output_shape.Dims(1);
+  const int output_width = output_shape.Dims(2);
 #ifdef USE_NEON
-  const bool shift_left = (output_shift <= 0);
-  const int32 multiplier_power_of_two = shift_left ? (1 << -output_shift) : 1;
+  const bool shift_left = (output_shift > 0);
+  const int32 multiplier_power_of_two = shift_left ? (1 << output_shift) : 1;
 #endif
-  TFLITE_DCHECK(output_depth == input_depth * depth_multiplier);
+  TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
+  TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
 
 // Enable for arm64 except for the Nvidia Linux 4 Tegra (L4T) running on
 // Jetson TX-2. This compiler does not support the offsetof() macro.
@@ -1703,14 +1727,12 @@
   // Call kernel optimized for depthwise convolutions using 3x3 filters if
   // parameters are supported.
   if (Fast3x3FilterKernelSupported(
-          input_dims, filter_dims, stride_width, stride_height, pad_width,
-          pad_height, depth_multiplier, output_dims, output_shift)) {
-    DepthwiseConv3x3Filter(input_data, input_dims, input_offset, filter_data,
-                           filter_dims, filter_offset, bias_data, bias_dims,
-                           stride_width, stride_height, pad_width, pad_height,
-                           depth_multiplier, output_offset, output_multiplier,
-                           output_shift, output_activation_min,
-                           output_activation_max, output_data, output_dims);
+          input_shape, filter_shape, stride_width, stride_height,
+          dilation_width_factor, dilation_height_factor, pad_width, pad_height,
+          depth_multiplier, output_shape, output_shift)) {
+    DepthwiseConv3x3Filter(params, input_shape, input_data, filter_shape,
+                           filter_data, bias_shape, bias_data, output_shape,
+                           output_data);
     return;
   }
 #endif
@@ -1734,7 +1756,8 @@
                                         FIXED_DEPTH_MULTIPLIER)           \
   if (!row_accum_func && (stride_width == 1 || ALLOW_STRIDED) &&          \
       (input_depth == FIXED_INPUT_DEPTH || FIXED_INPUT_DEPTH == 0) &&     \
-      depth_multiplier == FIXED_DEPTH_MULTIPLIER) {                       \
+      depth_multiplier == FIXED_DEPTH_MULTIPLIER &&                       \
+      dilation_width_factor == 1 && dilation_height_factor == 1) {        \
     row_accum_func =                                                      \
         QuantizedDepthwiseConvAccumRow<ALLOW_STRIDED, FIXED_INPUT_DEPTH,  \
                                        FIXED_DEPTH_MULTIPLIER>;           \
@@ -1785,14 +1808,22 @@
 
 #undef TFMINI_USE_DEPTHWISECONV_KERNEL
 
+  const int input_height_stride = input_shape.Dims(3) * input_shape.Dims(2);
+  const int input_batch_stride = input_height_stride * input_shape.Dims(1);
+  const int filter_height_stride = filter_shape.Dims(3) * filter_shape.Dims(2);
+
   // Now that we have determined row_accum_func, we can start work.
   uint8* output_ptr = output_data;
   for (int b = 0; b < batches; ++b) {
     for (int out_y = 0; out_y < output_height; ++out_y) {
       const int in_y_origin = (out_y * stride_height) - pad_height;
-      const int filter_y_start = std::max(0, -in_y_origin);
+      const int filter_y_start =
+          std::max(0, (-in_y_origin + dilation_height_factor - 1) /
+                          dilation_height_factor);
       const int filter_y_end =
-          std::min(filter_height, input_height - in_y_origin);
+          std::min(filter_height,
+                   (input_height - in_y_origin + dilation_height_factor - 1) /
+                       dilation_height_factor);
       for (int out_x_buffer_start = 0; out_x_buffer_start < output_width;
            out_x_buffer_start += kOutputPixelsInAccBuffer) {
         const int out_x_buffer_end = std::min(
@@ -1808,13 +1839,12 @@
         // Accumulation loop. Most of the time should be spent in here.
         for (int filter_y = filter_y_start; filter_y < filter_y_end;
              ++filter_y) {
-          const int in_y = in_y_origin + filter_y;
+          const int in_y = in_y_origin + dilation_height_factor * filter_y;
           row_accum_func(
-              stride_width, input_depth, input_width,
-              input_data + in_y * input_dims.strides[2] +
-                  b * input_dims.strides[3],
+              stride_width, dilation_width_factor, input_depth, input_width,
+              input_data + in_y * input_height_stride + b * input_batch_stride,
               input_offset, pad_width, depth_multiplier, filter_width,
-              filter_data + filter_y * filter_dims.strides[2], filter_offset,
+              filter_data + filter_y * filter_height_stride, filter_offset,
               out_x_buffer_start, out_x_buffer_end, output_depth, acc_buffer);
         }
         // Finished accumulating int32 values. Now need to convert them to
@@ -1845,7 +1875,7 @@
               acc[j] = vqrdmulhq_n_s32(acc[j], output_multiplier);
             }
             for (int j = 0; j < 4; j++) {
-              acc[j] = RoundingDivideByPOT(acc[j], output_shift);
+              acc[j] = RoundingDivideByPOT(acc[j], -output_shift);
             }
           } else {
             // Fixed-point multiplication.
@@ -1889,8 +1919,8 @@
             acc0 = vqrdmulhq_n_s32(acc0, output_multiplier);
             acc1 = vqrdmulhq_n_s32(acc1, output_multiplier);
             // Rounding right shift.
-            acc0 = RoundingDivideByPOT(acc0, output_shift);
-            acc1 = RoundingDivideByPOT(acc1, output_shift);
+            acc0 = RoundingDivideByPOT(acc0, -output_shift);
+            acc1 = RoundingDivideByPOT(acc1, -output_shift);
           } else {
             // Fixed-point multiplication.
             acc0 = vmulq_n_s32(acc0, multiplier_power_of_two);
@@ -1926,7 +1956,7 @@
             // Fixed-point multiplication.
             acc = vqrdmulhq_n_s32(acc, output_multiplier);
             // Rounding right shift.
-            acc = RoundingDivideByPOT(acc, output_shift);
+            acc = RoundingDivideByPOT(acc, -output_shift);
           } else {
             // Fixed-point multiplication.
             acc = vmulq_n_s32(acc, multiplier_power_of_two);
@@ -1953,7 +1983,7 @@
         for (; i < num_output_values; i++) {
           int32 acc = acc_buffer[i];
           acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
-                                              -output_shift);
+                                              output_shift);
           acc += output_offset;
           acc = std::max(acc, output_activation_min);
           acc = std::min(acc, output_activation_max);
@@ -1964,6 +1994,8 @@
   }
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
 inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
                           int32 input_offset, const uint8* filter_data,
                           const Dims<4>& filter_dims, int32 filter_offset,
@@ -1975,19 +2007,49 @@
                           int output_shift, int32 output_activation_min,
                           int32 output_activation_max, uint8* output_data,
                           const Dims<4>& output_dims) {
-  // TODO(suharshs): Optimized implementation of dilation depthwise is not
-  // supported yet.
-  TFLITE_DCHECK(dilation_width_factor == 1);
-  TFLITE_DCHECK(dilation_height_factor == 1);
+  tflite::DepthwiseParams op_params;
+  // The padding type is unused here; only the explicit padding values matter.
+  op_params.padding_type = PaddingType::kSame;
+  op_params.padding_values.width = pad_width;
+  op_params.padding_values.height = pad_height;
+  op_params.stride_width = stride_width;
+  op_params.stride_height = stride_height;
+  op_params.dilation_width_factor = dilation_width_factor;
+  op_params.dilation_height_factor = dilation_height_factor;
+  op_params.depth_multiplier = depth_multiplier;
+  op_params.quantized_activation_min = output_activation_min;
+  op_params.quantized_activation_max = output_activation_max;
+  op_params.input_offset = input_offset;
+  op_params.weights_offset = filter_offset;
+  op_params.output_offset = output_offset;
+  op_params.output_multiplier = output_multiplier;
+  // Legacy ops used mixed left and right shifts. Now all shifts follow the
+  // positive-means-left convention.
+  op_params.output_shift = kDepthwiseReverseShift * output_shift;
 
+  DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
+                DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+                bias_data, DimsToShape(output_dims), output_data);
+}
+
+inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+                          int32 input_offset, const uint8* filter_data,
+                          const Dims<4>& filter_dims, int32 filter_offset,
+                          const int32* bias_data, const Dims<4>& bias_dims,
+                          int stride_width, int stride_height, int pad_width,
+                          int pad_height, int depth_multiplier,
+                          int32 output_offset, int32 output_multiplier,
+                          int output_shift, int32 output_activation_min,
+                          int32 output_activation_max, uint8* output_data,
+                          const Dims<4>& output_dims) {
   DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
                 filter_offset, bias_data, bias_dims, stride_width,
-                stride_height, pad_width, pad_height, depth_multiplier,
+                stride_height, 1, 1, pad_width, pad_height, depth_multiplier,
                 output_offset, output_multiplier, output_shift,
                 output_activation_min, output_activation_max, output_data,
                 output_dims);
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
 // Legacy, for compatibility with old checked-in code.
 template <FusedActivationFunctionType Ac>
 void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
@@ -2011,6 +2073,7 @@
                 output_dims);
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
 // Legacy, for compatibility with old checked-in code.
 template <FusedActivationFunctionType Ac>
 void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
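All of the sign flips around output_shift in this file follow from one convention change: DepthwiseParams::output_shift is now positive-means-left-shift, while the legacy Dims<4> entry points still receive a right-shift amount, hence op_params.output_shift = kDepthwiseReverseShift * output_shift and RoundingDivideByPOT(acc, -output_shift) in the NEON paths. A plain-integer sketch of the convention (illustration only; the kernels use gemmlowp's saturating fixed-point primitives with proper rounding):

#include <cstdint>

// Applies only the power-of-two part of the output rescale, for illustration.
// Assumes acc >= 0 to keep the rounding trivial.
int32_t ApplyOutputShift(int32_t acc, int output_shift) {
  if (output_shift >= 0) {
    // Positive shift: multiply by 2^output_shift (shift left).
    return acc * (int32_t{1} << output_shift);
  }
  // Negative shift: rounding divide by 2^(-output_shift) (shift right).
  const int right_shift = -output_shift;
  const int32_t rounding = int32_t{1} << (right_shift - 1);
  return (acc + rounding) >> right_shift;
}

Under this convention, MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift) in the scalar tail and the RoundingDivideByPOT calls in the NEON code apply the same scaling, just expressed with opposite signs.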
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h
index 0ce64f8..4809ddd 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h
@@ -49,7 +49,7 @@
   int32 output_multiplier;
   int32 output_activation_min;
   int32 output_activation_max;
-  int32 output_shift;
+  int32 output_right_shift;
   int32 input_width;
   int32 input_height;
   int32 stride_width;
@@ -75,7 +75,7 @@
 #define OFFSET_OUTPUT_MULTIPLIER 52
 #define OFFSET_OUTPUT_ACTIVATION_MIN 56
 #define OFFSET_OUTPUT_ACTIVATION_MAX 60
-#define OFFSET_OUTPUT_SHIFT 64
+#define OFFSET_OUTPUT_RIGHT_SHIFT 64
 #define OFFSET_INPUT_WIDTH 68
 #define OFFSET_INPUT_HEIGHT 72
 #define OFFSET_STRIDE_WIDTH 76
@@ -105,8 +105,8 @@
                   OFFSET_OUTPUT_ACTIVATION_MIN, "");
 static_assert(offsetof(DepthwiseConvParams, output_activation_max) ==
                   OFFSET_OUTPUT_ACTIVATION_MAX, "");
-static_assert(offsetof(DepthwiseConvParams, output_shift) ==
-                  OFFSET_OUTPUT_SHIFT, "");
+static_assert(offsetof(DepthwiseConvParams, output_right_shift) ==
+                  OFFSET_OUTPUT_RIGHT_SHIFT, "");
 static_assert(offsetof(DepthwiseConvParams, input_width) ==
                   OFFSET_INPUT_WIDTH, "");
 static_assert(offsetof(DepthwiseConvParams, input_height) ==
@@ -189,7 +189,7 @@
         "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n"
         "ldr w2, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n"
         "dup v27.4s, w9\n"
-        "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n"
+        "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_RIGHT_SHIFT) "]\n"
         "dup v29.4s, w2\n"
         "ldr w4, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n"
         "dup v30.4s, w4\n"
@@ -1166,7 +1166,7 @@
         // values from time to time when there are not enough NEON registers.
         // We use x9--x15 general purpose registers as they are caller-saved
         // temporary registers (see http://infocenter.arm.com/help/topic/com.arm.doc.ihi0055b/IHI0055B_aapcs64.pdf).  // NOLINT
-        "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n"
+        "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_RIGHT_SHIFT) "]\n"
         "ldr w0, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n"
         "cmp %w[output_window_height], #2\n"
         "dup v28.8h, w0\n"
@@ -2216,7 +2216,7 @@
         "dup v27.4s, w10\n"
         "ld1 {v0.8b}, [%[filter_ptr]], #8\n"
         "cmp x11, #16\n"
-        "ldr w10, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n"
+        "ldr w10, [%[params_ptr], #" STR(OFFSET_OUTPUT_RIGHT_SHIFT) "]\n"
         "dup v28.4s, w9\n"
         "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n"
         "neg w10, w10\n"
@@ -2355,7 +2355,7 @@
         "dup v26.8h, w6\n"
         "ldr w6, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n"
         "dup v27.4s, w7\n"
-        "ldr w7, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n"
+        "ldr w7, [%[params_ptr], #" STR(OFFSET_OUTPUT_RIGHT_SHIFT) "]\n"
         "dup v28.4s, w6\n"
         "ldr w6, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n"
         "neg w7, w7\n"
@@ -2532,7 +2532,7 @@
         "dup v26.8h, w12\n"
         "ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n"
         "dup v27.4s, w13\n"
-        "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n"
+        "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_RIGHT_SHIFT) "]\n"
         "dup v28.4s, w12\n"
         "ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n"
         "neg w13, w13\n"
@@ -2739,7 +2739,7 @@
         "dup v26.8h, w12\n"
         "ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n"
         "dup v27.4s, w13\n"
-        "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n"
+        "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_RIGHT_SHIFT) "]\n"
         "dup v28.4s, w12\n"
         "ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n"
         "neg w13, w13\n"
@@ -2910,7 +2910,7 @@
 #undef OFFSET_OUTPUT_MULTIPLIER
 #undef OFFSET_OUTPUT_ACTIVATION_MIN
 #undef OFFSET_OUTPUT_ACTIVATION_MAX
-#undef OFFSET_OUTPUT_SHIFT
+#undef OFFSET_OUTPUT_RIGHT_SHIFT
 #undef OFFSET_INPUT_WIDTH
 #undef OFFSET_INPUT_HEIGHT
 #undef OFFSET_OUTPUT_WIDTH
@@ -3175,16 +3175,18 @@
 }
 
 inline bool Fast3x3FilterKernelSupported(
-    const Dims<4>& input_dims, const Dims<4>& filter_dims, int32 stride_width,
-    int32 stride_height, int32 pad_width, int32 pad_height,
-    int32 depth_multiplier, const Dims<4>& output_dims, int32 output_shift) {
-  const int32 input_height = ArraySize(input_dims, 2);
-  const int32 input_width = ArraySize(input_dims, 1);
-  const int32 input_depth = ArraySize(input_dims, 0);
-  const int32 filter_height = ArraySize(filter_dims, 2);
-  const int32 filter_width = ArraySize(filter_dims, 1);
-  const int32 output_height = ArraySize(output_dims, 2);
-  const int32 output_width = ArraySize(output_dims, 1);
+    const RuntimeShape& input_shape, const RuntimeShape& filter_shape,
+    int32 stride_width, int32 stride_height, int32 dilation_width_factor,
+    int32 dilation_height_factor, int32 pad_width, int32 pad_height,
+    int32 depth_multiplier, const RuntimeShape& output_shape,
+    int32 output_shift) {
+  const int32 input_height = input_shape.Dims(1);
+  const int32 input_width = input_shape.Dims(2);
+  const int32 input_depth = input_shape.Dims(3);
+  const int32 filter_height = filter_shape.Dims(1);
+  const int32 filter_width = filter_shape.Dims(2);
+  const int32 output_height = output_shape.Dims(1);
+  const int32 output_width = output_shape.Dims(2);
 
   bool supported =
       filter_width == 3 && filter_height == 3 && depth_multiplier == 1 &&
@@ -3192,7 +3194,8 @@
       (stride_height == 1 || stride_height == 2) &&
       (stride_width == stride_height) && (pad_width == 0 || pad_width == 1) &&
       (pad_height == 0 || pad_height == 1) && (pad_width == pad_height) &&
-      (input_depth % 8) == 0 && (output_shift > 0);
+      (input_depth % 8) == 0 && (output_shift <= 0) &&
+      dilation_width_factor == 1 && dilation_height_factor == 1;
 
   if (!supported) {
     return false;
@@ -3234,36 +3237,47 @@
 }
 
 inline void DepthwiseConv3x3Filter(
-    const uint8* input_data, const Dims<4>& input_dims, int32 input_offset,
-    const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset,
-    const int32* bias_data, const Dims<4>& bias_dims, int32 stride_width,
-    int32 stride_height, int32 pad_width, int32 pad_height,
-    int32 depth_multiplier, int32 output_offset, int32 output_multiplier,
-    int32 output_shift, int32 output_activation_min,
-    int32 output_activation_max, uint8* output_data,
-    const Dims<4>& output_dims) {
+    const DepthwiseParams& rt_params, const RuntimeShape& input_shape,
+    const uint8* input_data, const RuntimeShape& filter_shape,
+    const uint8* filter_data, const RuntimeShape& bias_shape,
+    const int32* bias_data, const RuntimeShape& output_shape,
+    uint8* output_data) {
   gemmlowp::ScopedProfilingLabel label(__PRETTY_FUNCTION__);
   DepthwiseConvParams params;
-  params.input_depth = ArraySize(input_dims, 0);
-  params.input_width = ArraySize(input_dims, 1);
-  params.input_height = ArraySize(input_dims, 2);
+
+  const int32 stride_width = rt_params.stride_width;
+  const int32 stride_height = rt_params.stride_height;
+  const int32 pad_width = rt_params.padding_values.width;
+  const int32 pad_height = rt_params.padding_values.height;
+  const int32 depth_multiplier = rt_params.depth_multiplier;
+  const int32 output_activation_min = rt_params.quantized_activation_min;
+  const int32 output_activation_max = rt_params.quantized_activation_max;
+  const int32 input_offset = rt_params.input_offset;
+  const int32 filter_offset = rt_params.weights_offset;
+  const int32 output_offset = rt_params.output_offset;
+  const int32 output_multiplier = rt_params.output_multiplier;
+  const int32 output_shift = rt_params.output_shift;
+
+  params.input_depth = input_shape.Dims(3);
+  params.input_width = input_shape.Dims(2);
+  params.input_height = input_shape.Dims(1);
   params.input_row_size = params.input_depth * params.input_width;
   params.input_offset = input_offset;
   params.stride_width = stride_width;
   params.stride_height = stride_height;
-  params.output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
-  params.output_width = ArraySize(output_dims, 1);
-  params.output_height = ArraySize(output_dims, 2);
+  params.output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
+  params.output_width = output_shape.Dims(2);
+  params.output_height = output_shape.Dims(1);
   params.output_row_size = params.output_depth * params.output_width;
   params.output_offset = output_offset;
   params.filter_offset = filter_offset;
   params.output_multiplier = output_multiplier;
-  params.output_shift = output_shift;
+  params.output_right_shift = -output_shift;
   params.output_activation_min = output_activation_min;
   params.output_activation_max = output_activation_max;
 
-  const int32 filter_height = ArraySize(filter_dims, 2);
-  const int32 filter_width = ArraySize(filter_dims, 1);
+  const int32 filter_height = filter_shape.Dims(1);
+  const int32 filter_width = filter_shape.Dims(2);
   params.filter_row_size = params.output_depth * filter_width;
 
   // Algorithm assumes below constraints. It is optimized for depth
@@ -3279,7 +3293,7 @@
   TFLITE_DCHECK(pad_width == 0 || pad_width == 1);
   TFLITE_DCHECK(pad_width == pad_height);
 
-  const int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+  const int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
   const int64_t input_batch_size = params.input_row_size * params.input_height;
   const int64_t output_batch_size =
       params.output_row_size * params.output_height;
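The rename from output_shift to output_right_shift in DepthwiseConvParams makes the assembly-facing struct explicit about its convention: the hand-written kernels load a right-shift amount (and negate it themselves where a left shift is needed, as the neg instructions above show), so DepthwiseConv3x3Filter now stores params.output_right_shift = -output_shift. A trivial sketch of that conversion (the struct names here are illustrative, not the real types):

#include <cstdint>

// Caller-side convention: positive output_shift means shift left.
struct CallerParams {
  int32_t output_shift;
};

// Assembly-facing convention: positive value means shift right.
struct KernelParams {
  int32_t output_right_shift;
};

KernelParams ToKernelConvention(const CallerParams& caller) {
  KernelParams kernel;
  kernel.output_right_shift = -caller.output_shift;
  return kernel;
}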
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
index 2741817..36c15db 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
@@ -457,7 +457,7 @@
     return;
   }
   *scaling_factor = range / kScale;
-  const float scaling_factor_inv = 1.0f / *scaling_factor;
+  const float scaling_factor_inv = kScale / range;
 
   const int postamble_start =
       size - (size & (2 * kFloatWeightsPerNeonLane - 1));
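The one-line change in neon_tensor_utils.cc computes the inverse scale directly as kScale / range rather than dividing by the just-computed scaling factor; the two are mathematically equivalent, but the direct form involves one rounding step instead of two. A rough scalar outline of the symmetric quantization this routine performs (a sketch assuming kScale is 127 and int8 outputs; it is not the NEON implementation and the zero-range handling is illustrative):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

void SymmetricQuantizeSketch(const std::vector<float>& values,
                             float* scaling_factor,
                             std::vector<int8_t>* quantized) {
  constexpr float kScale = 127.0f;
  float range = 0.0f;
  for (const float v : values) range = std::max(range, std::fabs(v));
  quantized->assign(values.size(), 0);
  if (range == 0.0f) {
    *scaling_factor = 1.0f;
    return;
  }
  *scaling_factor = range / kScale;
  // Same value as 1.0f / *scaling_factor, computed without the extra division.
  const float scaling_factor_inv = kScale / range;
  for (size_t i = 0; i < values.size(); ++i) {
    const float rescaled = std::round(values[i] * scaling_factor_inv);
    (*quantized)[i] = static_cast<int8_t>(
        std::min(kScale, std::max(-kScale, rescaled)));
  }
}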
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index aaf93ae..c86e549 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -81,20 +81,16 @@
 using reference_ops::SpaceToBatchND;
 using reference_ops::Split;
 using reference_ops::StridedSlice;
+using reference_ops::TensorFlowSplit;
 using reference_ops::Transpose;
 
 // TODO(b/80247582) Remove this constant.
 // This will be phased out as the shifts are revised with more thought. Use of a
 // constant enables us to track progress on this work.
 //
-// Used mainly to convert from old-style shifts (right) to new-style (left).
+// Used to convert from old-style shifts (right) to new-style (left).
 static constexpr int kReverseShift = -1;
 
-inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) {
-  return RuntimeShape(
-      {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
-}
-
 // Make a local VectorMap typedef allowing to map a float array
 // as a Eigen vector expression. The std::conditional here is to
 // construct the suitable Eigen type for the constness of the
@@ -188,6 +184,15 @@
   return ArrayMap<Scalar>(data, rows, cols);
 }
 
+template <typename Scalar>
+ArrayMap<Scalar> MapAsArrayWithLastDimAsRows(Scalar* data,
+                                             const RuntimeShape& shape) {
+  const int dims_count = shape.DimensionsCount();
+  const int rows = shape.Dims(dims_count - 1);
+  const int cols = FlatSizeSkipDim(shape, dims_count - 1);
+  return ArrayMap<Scalar>(data, rows, cols);
+}
+
 // Copied from tensorflow/core/framework/tensor_types.h
 template <typename T, int NDIMS = 1, typename IndexType = Eigen::DenseIndex>
 struct TTypes {
@@ -972,7 +977,7 @@
   const int output_size = MatchingDim(filter_shape, filter_dim_count - 2,
                                       output_shape, output_dim_count - 1);
   static constexpr int kPeel = 4;
-  const bool shift_left = (output_shift <= 0);
+  const bool shift_left = (output_shift > 0);
   for (int k = 0; k < input_size; k += 64) {
     optimized_ops_preload_l1_stream(input_data + k);
   }
@@ -1085,7 +1090,7 @@
     bias_ptr += 4;
     reduced = vaddq_s32(reduced, bias_vec);
     if (shift_left) {
-      const int32 multiplier_power_of_two = 1 << -output_shift;
+      const int32 multiplier_power_of_two = 1 << output_shift;
       reduced = vmulq_n_s32(reduced, multiplier_power_of_two);
       reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
     } else {
@@ -1093,7 +1098,7 @@
       reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
       // Rounding-shift-right.
       using gemmlowp::RoundingDivideByPOT;
-      reduced = RoundingDivideByPOT(reduced, output_shift);
+      reduced = RoundingDivideByPOT(reduced, -output_shift);
     }
     // Add the output offset.
     const int32x4_t output_offset_vec = vdupq_n_s32(output_offset);
@@ -1190,7 +1195,7 @@
   gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
       output_data, output_rows, batches, output_rows);
   const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp(
-      bias_data, output_rows, output_offset, output_multiplier, -output_shift,
+      bias_data, output_rows, output_offset, output_multiplier, output_shift,
       output_activation_min, output_activation_max);
   gemmlowp::GemmWithOutputPipeline<uint8, uint8,
                                    gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
@@ -1214,7 +1219,8 @@
   op_params.weights_offset = filter_offset;
   op_params.output_offset = output_offset;
   op_params.output_multiplier = output_multiplier;
-  op_params.output_shift = output_shift;
+  // Legacy ops mixed left and right shifts; positive now means a left shift.
+  op_params.output_shift = kReverseShift * output_shift;
   op_params.quantized_activation_min = output_activation_min;
   op_params.quantized_activation_max = output_activation_max;
 
@@ -1269,14 +1275,14 @@
     if (filter_offset == -128 && !(output_depth % 4) && !(accum_depth % 64)) {
       GEMVForLstmCellWithSymmetricRange(
           input_shape, input_data, filter_shape, filter_data, bias_shape,
-          bias_data_int32, output_multiplier, -output_shift, output_shape,
+          bias_data_int32, output_multiplier, output_shift, output_shape,
           output_data);
       return;
     }
     if (!(output_depth % 4) && !(accum_depth % 8)) {
       GEMVForLstmCell(input_shape, input_data, filter_shape, filter_data,
                       filter_offset, bias_shape, bias_data_int32,
-                      output_multiplier, -output_shift, output_shape,
+                      output_multiplier, output_shift, output_shape,
                       output_data);
       return;
     }
@@ -1297,7 +1303,7 @@
   scale_stage.result_offset_after_shift = 0;
   scale_stage.result_fixedpoint_multiplier = output_multiplier;
   // Note that this shift is negated wrt ordinary FC.
-  scale_stage.result_exponent = -output_shift;
+  scale_stage.result_exponent = output_shift;
   gemmlowp::OutputStageClamp clamp_stage;
   clamp_stage.min = output_activation_min;
   clamp_stage.max = output_activation_max;
@@ -1325,7 +1331,8 @@
   op_params.weights_offset = filter_offset;
   op_params.output_offset = output_offset;
   op_params.output_multiplier = output_multiplier;
-  op_params.output_shift = output_shift;
+  // Legacy ops mixed left and right shifts; positive now means a left shift.
+  op_params.output_shift = kReverseShift * output_shift;
   op_params.quantized_activation_min = output_activation_min;
   op_params.quantized_activation_max = output_activation_max;
 
@@ -1371,8 +1378,8 @@
 #if defined USE_NEON
   const int8* shuffled_weights_ptr = shuffled_weights_data;
   if (batches == 1) {
-    const int right_shift = output_shift > 0 ? output_shift : 0;
-    const int left_shift = output_shift > 0 ? 0 : -output_shift;
+    const int right_shift = output_shift > 0 ? 0 : -output_shift;
+    const int left_shift = output_shift > 0 ? output_shift : 0;
     for (int c = 0; c < output_depth; c += 4) {
       // Accumulation loop.
       int32x4_t row_accum0 = vdupq_n_s32(0);
@@ -1438,8 +1445,8 @@
       vst1_s16(output_data + c, res16);
     }
   } else if (batches == 4) {
-    const int right_shift = output_shift > 0 ? output_shift : 0;
-    const int left_shift = output_shift > 0 ? 0 : -output_shift;
+    const int right_shift = output_shift > 0 ? 0 : -output_shift;
+    const int left_shift = output_shift > 0 ? output_shift : 0;
     for (int c = 0; c < output_depth; c += 4) {
       const int8* shuffled_input_ptr =
           reinterpret_cast<const int8*>(shuffled_input_workspace_data);
@@ -1570,8 +1577,8 @@
         // (16-bit, typically 3 integer bits) fixed-point format. The quantized
         // multiplier and shift here have been pre-computed offline
         // (e.g. by toco).
-        acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
-                                            -output_shift);
+        acc =
+            MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift);
         // Saturate, cast to int16, and store to output array.
         acc = std::max(acc, -32768);
         acc = std::min(acc, 32767);
@@ -1622,7 +1629,7 @@
           // quantized multiplier and shift here have been pre-computed offline
           // (e.g. by toco).
           acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
-                                              -output_shift);
+                                              output_shift);
           // Saturate, cast to int16, and store to output array.
           acc = std::max(acc, -32768);
           acc = std::min(acc, 32767);
@@ -1813,7 +1820,8 @@
     uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) {
   tflite::FullyConnectedParams op_params;
   op_params.output_multiplier = output_multiplier;
-  op_params.output_shift = output_shift;
+  // Legacy ops mixed left and right shifts; positive now means a left shift.
+  op_params.output_shift = kReverseShift * output_shift;
   op_params.quantized_activation_min = output_activation_min;
   op_params.quantized_activation_max = output_activation_max;
 
@@ -2205,7 +2213,6 @@
   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
-  TFLITE_DCHECK_EQ(im2col_shape.DimensionsCount(), 4);
 
   const int batch_size = input_shape.Dims(0);
   const int filter_width = filter_shape.Dims(2);
@@ -2371,7 +2378,6 @@
   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
-  TFLITE_DCHECK_EQ(im2col_shape.DimensionsCount(), 4);
 
   const uint8* gemm_input_data = nullptr;
   const RuntimeShape* gemm_input_shape = nullptr;
@@ -2434,7 +2440,7 @@
   gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
       output_data, output_rows, output_cols);
   const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp(
-      bias_data, output_rows, output_offset, output_multiplier, -output_shift,
+      bias_data, output_rows, output_offset, output_multiplier, output_shift,
       output_activation_min, output_activation_max);
   gemmlowp::GemmWithOutputPipeline<uint8, uint8,
                                    gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
@@ -2468,7 +2474,8 @@
   op_params.weights_offset = filter_offset;
   op_params.output_offset = output_offset;
   op_params.output_multiplier = output_multiplier;
-  op_params.output_shift = output_shift;
+  // Legacy ops mixed left and right shifts; positive now means a left shift.
+  op_params.output_shift = kReverseShift * output_shift;
   op_params.quantized_activation_min = output_activation_min;
   op_params.quantized_activation_max = output_activation_max;
 
@@ -2789,6 +2796,7 @@
     *output_inv_sqrt <<= -*output_shift;
     *output_shift = 0;
   }
+  // Convert right shift (right is positive) to left shift.
   *output_shift *= kReverseShift;
 }
 
@@ -3633,62 +3641,96 @@
   }
 }
 
-inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
-                     const float* prev_activ_data,
-                     const Dims<4>& prev_activ_dims, const float* weights_data,
-                     const Dims<4>& weights_dims, const float* bias_data,
-                     const Dims<4>& bias_dims, const float* prev_state_data,
-                     const Dims<4>& prev_state_dims, float* output_state_data,
-                     const Dims<4>& output_state_dims, float* output_activ_data,
-                     const Dims<4>& output_activ_dims, float* concat_temp_data,
-                     const Dims<4>& concat_temp_dims, float* activ_temp_data,
-                     const Dims<4>& activ_temp_dims) {
+inline void LstmCell(
+    const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
+    const float* input_data, const RuntimeShape& unextended_prev_activ_shape,
+    const float* prev_activ_data, const RuntimeShape& weights_shape,
+    const float* weights_data, const RuntimeShape& unextended_bias_shape,
+    const float* bias_data, const RuntimeShape& unextended_prev_state_shape,
+    const float* prev_state_data,
+    const RuntimeShape& unextended_output_state_shape, float* output_state_data,
+    const RuntimeShape& unextended_output_activ_shape, float* output_activ_data,
+    const RuntimeShape& unextended_concat_temp_shape, float* concat_temp_data,
+    const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data) {
   gemmlowp::ScopedProfilingLabel label("LstmCell");
-  MatchingArraySize(  // batches
-      input_dims, 3, prev_activ_dims, 3, prev_state_dims, 3, output_state_dims,
-      3, output_activ_dims, 3);
-  MatchingArraySize(  // height
-      input_dims, 2, prev_activ_dims, 2, prev_state_dims, 2, output_state_dims,
-      2, output_activ_dims, 2);
-  MatchingArraySize(  // width
-      input_dims, 1, prev_activ_dims, 1, prev_state_dims, 1, output_state_dims,
-      1, output_activ_dims, 1);
-  TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1);
-  TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1);
-  const int input_depth = ArraySize(input_dims, 0);
-  const int prev_activ_depth = ArraySize(prev_activ_dims, 0);
+  TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
+  const RuntimeShape input_shape =
+      RuntimeShape::ExtendedShape(4, unextended_input_shape);
+  const RuntimeShape prev_activ_shape =
+      RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
+  const RuntimeShape bias_shape =
+      RuntimeShape::ExtendedShape(4, unextended_bias_shape);
+  const RuntimeShape prev_state_shape =
+      RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
+  const RuntimeShape output_state_shape =
+      RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
+  const RuntimeShape output_activ_shape =
+      RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
+  const RuntimeShape concat_temp_shape =
+      RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
+  const RuntimeShape activ_temp_shape =
+      RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
+  TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
+
+  const int weights_dim_count = weights_shape.DimensionsCount();
+  MatchingDim(  // batches
+      input_shape, 0, prev_activ_shape, 0, prev_state_shape, 0,
+      output_state_shape, 0, output_activ_shape, 0);
+  MatchingDim(  // height
+      input_shape, 1, prev_activ_shape, 1, prev_state_shape, 1,
+      output_state_shape, 1, output_activ_shape, 1);
+  MatchingDim(  // width
+      input_shape, 2, prev_activ_shape, 2, prev_state_shape, 2,
+      output_state_shape, 2, output_activ_shape, 2);
+  const int input_depth = input_shape.Dims(3);
+  const int prev_activ_depth = prev_activ_shape.Dims(3);
   const int total_input_depth = prev_activ_depth + input_depth;
-  TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth);
-  TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3),
-                  1);
+  TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
+                   total_input_depth);
+  TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
   const int intern_activ_depth =
-      MatchingArraySize(weights_dims, 1, bias_dims, 0);
-  TFLITE_CHECK_EQ(intern_activ_depth % 4, 0);
+      MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
+  TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
+                   intern_activ_depth * total_input_depth);
+  TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
   const int output_depth =
-      MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0,
-                        output_state_dims, 0, output_activ_dims, 0);
-  TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4);
+      MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
+                  3, output_activ_shape, 3);
+  TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
 
   // Concatenate prev_activ and input data together
   std::vector<float const*> concat_input_arrays_data;
-  std::vector<Dims<4> const*> concat_input_arrays_dims;
+  std::vector<RuntimeShape const*> concat_input_arrays_shapes;
   concat_input_arrays_data.push_back(input_data);
   concat_input_arrays_data.push_back(prev_activ_data);
-  concat_input_arrays_dims.push_back(&input_dims);
-  concat_input_arrays_dims.push_back(&prev_activ_dims);
-  Concatenation<FusedActivationFunctionType::kNone, float>(
-      0, &(concat_input_arrays_data[0]), &(concat_input_arrays_dims[0]),
-      concat_input_arrays_data.size(), concat_temp_data, concat_temp_dims);
+  concat_input_arrays_shapes.push_back(&input_shape);
+  concat_input_arrays_shapes.push_back(&prev_activ_shape);
+  tflite::ConcatenationParams concat_params;
+  concat_params.axis = 3;
+  concat_params.inputs_count = concat_input_arrays_data.size();
+  Concatenation(concat_params, &(concat_input_arrays_shapes[0]),
+                &(concat_input_arrays_data[0]), concat_temp_shape,
+                concat_temp_data);
 
   // Fully connected
-  FullyConnected<FusedActivationFunctionType::kNone>(
-      concat_temp_data, concat_temp_dims, weights_data, weights_dims, bias_data,
-      bias_dims, activ_temp_data, activ_temp_dims);
+  tflite::FullyConnectedParams fc_params;
+  fc_params.float_activation_min = std::numeric_limits<float>::lowest();
+  fc_params.float_activation_max = std::numeric_limits<float>::max();
+  FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape,
+                 weights_data, bias_shape, bias_data, activ_temp_shape,
+                 activ_temp_data);
 
   // Map raw arrays to Eigen arrays so we can use Eigen's optimized array
   // operations.
   ArrayMap<float> activ_temp_map =
-      MapAsArrayWithFirstDimAsRows(activ_temp_data, activ_temp_dims);
+      MapAsArrayWithLastDimAsRows(activ_temp_data, activ_temp_shape);
   auto input_gate_sm = activ_temp_map.block(0 * output_depth, 0, output_depth,
                                             activ_temp_map.cols());
   auto new_input_sm = activ_temp_map.block(1 * output_depth, 0, output_depth,
@@ -3698,11 +3740,11 @@
   auto output_gate_sm = activ_temp_map.block(3 * output_depth, 0, output_depth,
                                              activ_temp_map.cols());
   ArrayMap<const float> prev_state_map =
-      MapAsArrayWithFirstDimAsRows(prev_state_data, prev_state_dims);
+      MapAsArrayWithLastDimAsRows(prev_state_data, prev_state_shape);
   ArrayMap<float> output_state_map =
-      MapAsArrayWithFirstDimAsRows(output_state_data, output_state_dims);
+      MapAsArrayWithLastDimAsRows(output_state_data, output_state_shape);
   ArrayMap<float> output_activ_map =
-      MapAsArrayWithFirstDimAsRows(output_activ_data, output_activ_dims);
+      MapAsArrayWithLastDimAsRows(output_activ_data, output_activ_shape);
 
   // Combined memory state and final output calculation
   gemmlowp::ScopedProfilingLabel label2("MemoryStateAndFinalOutput");
@@ -3716,56 +3758,120 @@
       output_state_map.tanh();
 }
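A brief sketch of the shape-extension pattern the rewritten LstmCell above relies on (values illustrative): inputs of rank lower than four are left-padded with 1s before any dimension checks run.

RuntimeShape small({5, 16});                                  // rank-2 shape
RuntimeShape padded = RuntimeShape::ExtendedShape(4, small);
// padded.DimensionsCount() == 4; its dims are {1, 1, 5, 16}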
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
+                     const float* prev_activ_data,
+                     const Dims<4>& prev_activ_dims, const float* weights_data,
+                     const Dims<4>& weights_dims, const float* bias_data,
+                     const Dims<4>& bias_dims, const float* prev_state_data,
+                     const Dims<4>& prev_state_dims, float* output_state_data,
+                     const Dims<4>& output_state_dims, float* output_activ_data,
+                     const Dims<4>& output_activ_dims, float* concat_temp_data,
+                     const Dims<4>& concat_temp_dims, float* activ_temp_data,
+                     const Dims<4>& activ_temp_dims) {
+  tflite::LstmCellParams op_params;
+  // Float LSTM cell does not need parameters to be set: leave untouched.
+
+  LstmCell(op_params, DimsToShape(input_dims), input_data,
+           DimsToShape(prev_activ_dims), prev_activ_data,
+           DimsToShape(weights_dims), weights_data, DimsToShape(bias_dims),
+           bias_data, DimsToShape(prev_state_dims), prev_state_data,
+           DimsToShape(output_state_dims), output_state_data,
+           DimsToShape(output_activ_dims), output_activ_data,
+           DimsToShape(concat_temp_dims), concat_temp_data,
+           DimsToShape(activ_temp_dims), activ_temp_data);
+}
+
 // Quantized LSTM cell. Currently just a copy of the reference impl in
 // reference_ops.h. See the big function comment there, not replicating it
 // here.
 template <int StateIntegerBits>
-void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
-              const uint8* prev_activ_data_uint8,
-              const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
-              const Dims<4>& weights_dims, const int32* bias_data_int32,
-              const Dims<4>& bias_dims, const int16* prev_state_data_int16,
-              const Dims<4>& prev_state_dims, int16* output_state_data_int16,
-              const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
-              const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
-              const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
-              const Dims<4>& activ_temp_dims, int32 weights_zero_point,
-              int32 accum_multiplier, int accum_shift,
-              gemmlowp::GemmContext* gemm_context) {
+inline void LstmCell(
+    const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
+    const uint8* input_data_uint8,
+    const RuntimeShape& unextended_prev_activ_shape,
+    const uint8* prev_activ_data_uint8, const RuntimeShape& weights_shape,
+    const uint8* weights_data_uint8, const RuntimeShape& unextended_bias_shape,
+    const int32* bias_data_int32,
+    const RuntimeShape& unextended_prev_state_shape,
+    const int16* prev_state_data_int16,
+    const RuntimeShape& unextended_output_state_shape,
+    int16* output_state_data_int16,
+    const RuntimeShape& unextended_output_activ_shape,
+    uint8* output_activ_data_uint8,
+    const RuntimeShape& unextended_concat_temp_shape,
+    uint8* concat_temp_data_uint8,
+    const RuntimeShape& unextended_activ_temp_shape,
+    int16* activ_temp_data_int16, gemmlowp::GemmContext* gemm_context) {
   gemmlowp::ScopedProfilingLabel label(
       "LstmCell/quantized (8bit external, 16bit internal)");
+  int32 weights_zero_point = params.weights_zero_point;
+  int32 accum_multiplier = params.accum_multiplier;
+  int accum_shift = params.accum_shift;
+  TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
+  const RuntimeShape input_shape =
+      RuntimeShape::ExtendedShape(4, unextended_input_shape);
+  const RuntimeShape prev_activ_shape =
+      RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
+  const RuntimeShape bias_shape =
+      RuntimeShape::ExtendedShape(4, unextended_bias_shape);
+  const RuntimeShape prev_state_shape =
+      RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
+  const RuntimeShape output_state_shape =
+      RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
+  const RuntimeShape output_activ_shape =
+      RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
+  const RuntimeShape concat_temp_shape =
+      RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
+  const RuntimeShape activ_temp_shape =
+      RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
+  TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
+
   // Gather dimensions information, and perform consistency checks.
-  const int outer_size =
-      MatchingFlatSizeSkipDim(input_dims, 0, prev_activ_dims, prev_state_dims,
-                              output_state_dims, output_activ_dims);
-  TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1);
-  TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1);
-  const int input_depth = ArraySize(input_dims, 0);
-  const int prev_activ_depth = ArraySize(prev_activ_dims, 0);
+  const int weights_dim_count = weights_shape.DimensionsCount();
+  const int outer_size = MatchingFlatSizeSkipDim(
+      input_shape, 3, prev_activ_shape, prev_state_shape, output_state_shape,
+      output_activ_shape);
+  const int input_depth = input_shape.Dims(3);
+  const int prev_activ_depth = prev_activ_shape.Dims(3);
   const int total_input_depth = prev_activ_depth + input_depth;
-  TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth);
-  TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3),
-                  1);
+  TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
+                   total_input_depth);
   const int intern_activ_depth =
-      MatchingArraySize(weights_dims, 1, bias_dims, 0);
-  TFLITE_CHECK_EQ(intern_activ_depth % 4, 0);
+      MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
+  TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
+                   intern_activ_depth * total_input_depth);
+  TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
+  TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
   const int output_depth =
-      MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0,
-                        output_state_dims, 0, output_activ_dims, 0);
-  TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4);
-  const int fc_batches = FlatSizeSkipDim(activ_temp_dims, 0);
+      MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
+                  3, output_activ_shape, 3);
+  TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
+  const int fc_batches = FlatSizeSkipDim(activ_temp_shape, 3);
   const int fc_output_depth =
-      MatchingArraySize(weights_dims, 1, activ_temp_dims, 0);
-  const int fc_accum_depth = ArraySize(weights_dims, 0);
-  TFLITE_CHECK_EQ(fc_output_depth, 4 * output_depth);
+      MatchingDim(weights_shape, weights_dim_count - 2, activ_temp_shape, 3);
+  const int fc_accum_depth = total_input_depth;
+  TFLITE_DCHECK_EQ(fc_output_depth, 4 * output_depth);
 
   // Depth-concatenate prev_activ and input data together.
   uint8 const* concat_input_arrays_data[2] = {input_data_uint8,
                                               prev_activ_data_uint8};
-  Dims<4> const* concat_input_arrays_dims[2] = {&input_dims, &prev_activ_dims};
-  Concatenation<FusedActivationFunctionType::kNone, uint8>(
-      0, concat_input_arrays_data, concat_input_arrays_dims, 2,
-      concat_temp_data_uint8, concat_temp_dims);
+  const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape,
+                                                       &prev_activ_shape};
+  tflite::ConcatenationParams concat_params;
+  concat_params.axis = 3;
+  concat_params.inputs_count = 2;
+  Concatenation(concat_params, concat_input_arrays_shapes,
+                concat_input_arrays_data, concat_temp_shape,
+                concat_temp_data_uint8);
 
   // Implementation of the fully connected node inside the LSTM cell.
   // The operands are 8-bit integers, the accumulators are internally 32bit
@@ -3775,11 +3881,10 @@
   bool gemm_already_performed = false;
 #ifdef GEMMLOWP_NEON
   if (fc_batches == 1 && !(fc_output_depth % 4) && !(fc_accum_depth % 8)) {
-    GEMVForLstmCell(DimsToShape(concat_temp_dims), concat_temp_data_uint8,
-                    DimsToShape(weights_dims), weights_data_uint8,
-                    weights_zero_point, DimsToShape(bias_dims), bias_data_int32,
-                    accum_multiplier, accum_shift, DimsToShape(activ_temp_dims),
-                    activ_temp_data_int16);
+    GEMVForLstmCell(concat_temp_shape, concat_temp_data_uint8, weights_shape,
+                    weights_data_uint8, weights_zero_point, bias_shape,
+                    bias_data_int32, accum_multiplier, accum_shift,
+                    activ_temp_shape, activ_temp_data_int16);
     gemm_already_performed = true;
   }
 #endif
@@ -3968,28 +4073,35 @@
   }
 }
 
-template <FusedActivationFunctionType Ac, typename Scalar>
-void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
-                     int outputs_count, Scalar* const* output_data,
-                     const Dims<4>* const* output_dims) {
-  gemmlowp::ScopedProfilingLabel label("TensorFlowSplit");
-  TFLITE_DCHECK_GE(outputs_count, 1);
-  for (int i = 0; i < outputs_count; i++) {
-    MatchingFlatSizeSkipDim(*output_dims[i], 0, input_dims);
-  }
-  const int outer_size = FlatSizeSkipDim(input_dims, 0);
-  TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
-  // For now we don't have a model with a TensorFlowSplit
-  // with fused activation function.
-  TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
-  const Scalar* input_ptr = input_data;
-  for (int k = 0; k < outer_size; k++) {
-    for (int i = 0; i < outputs_count; ++i) {
-      memcpy(output_data[i] + k * output_dims[i]->sizes[0], input_ptr,
-             output_dims[i]->sizes[0] * sizeof(Scalar));
-      input_ptr += output_dims[i]->sizes[0];
-    }
-  }
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <int StateIntegerBits>
+void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
+              const uint8* prev_activ_data_uint8,
+              const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
+              const Dims<4>& weights_dims, const int32* bias_data_int32,
+              const Dims<4>& bias_dims, const int16* prev_state_data_int16,
+              const Dims<4>& prev_state_dims, int16* output_state_data_int16,
+              const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
+              const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
+              const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
+              const Dims<4>& activ_temp_dims, int32 weights_zero_point,
+              int32 accum_multiplier, int accum_shift,
+              gemmlowp::GemmContext* gemm_context) {
+  tflite::LstmCellParams op_params;
+  op_params.weights_zero_point = weights_zero_point;
+  op_params.accum_multiplier = accum_multiplier;
+  op_params.accum_shift = accum_shift;
+
+  LstmCell<StateIntegerBits>(
+      op_params, DimsToShape(input_dims), input_data_uint8,
+      DimsToShape(prev_activ_dims), prev_activ_data_uint8,
+      DimsToShape(weights_dims), weights_data_uint8, DimsToShape(bias_dims),
+      bias_data_int32, DimsToShape(prev_state_dims), prev_state_data_int16,
+      DimsToShape(output_state_dims), output_state_data_int16,
+      DimsToShape(output_activ_dims), output_activ_data_uint8,
+      DimsToShape(concat_temp_dims), concat_temp_data_uint8,
+      DimsToShape(activ_temp_dims), activ_temp_data_int16, gemm_context);
 }
 
 inline int NodeOffset(int b, int h, int w, int height, int width) {
@@ -4431,9 +4543,9 @@
   }
 }
 
-inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
-                    float beta, float* output_data,
-                    const RuntimeShape& output_shape) {
+inline void Softmax(const SoftmaxParams& params,
+                    const RuntimeShape& input_shape, const float* input_data,
+                    const RuntimeShape& output_shape, float* output_data) {
   gemmlowp::ScopedProfilingLabel label("Softmax");
   MatchingFlatSize(input_shape, output_shape);
 
@@ -4441,7 +4553,8 @@
   auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
   // Compute the exponential first, removing the max coefficient for numerical
   // stability.
-  out_mat = (in_mat.rowwise() - in_mat.colwise().maxCoeff()).array() * beta;
+  out_mat =
+      (in_mat.rowwise() - in_mat.colwise().maxCoeff()).array() * params.beta;
   // We are separating out the exp function so that exp can be vectorized.
   out_mat = out_mat.array().exp();
   // Normalize to get the activations.
@@ -4450,10 +4563,22 @@
   out_mat.array().rowwise() *= scale;
 }
 
-inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
-                    int32 input_beta_multiplier, int32 input_beta_left_shift,
-                    int diff_min, uint8* output_data,
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
+                    float beta, float* output_data,
                     const RuntimeShape& output_shape) {
+  SoftmaxParams params;
+  params.beta = beta;
+  Softmax(params, input_shape, input_data, output_shape, output_data);
+}
+
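A short usage sketch for the params-style float Softmax above (buffer sizes and beta value are illustrative): new-style signatures take the params struct first, then input shape/data, then output shape/data.

RuntimeShape shape({1, 10});
float logits[10] = {};  // illustrative input
float probs[10];
SoftmaxParams softmax_params;
softmax_params.beta = 1.0f;
Softmax(softmax_params, shape, logits, shape, probs);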
+inline void Softmax(const SoftmaxParams& params,
+                    const RuntimeShape& input_shape, const uint8* input_data,
+                    const RuntimeShape& output_shape, uint8* output_data) {
+  const int32 input_beta_multiplier = params.input_multiplier;
+  const int32 input_beta_left_shift = params.input_left_shift;
+  const int diff_min = params.diff_min;
   // The representation chosen for the input to the exp() function is Q5.26.
   // We need to leave extra space since values that we skip might be as large as
   // -32 before multiplying by input_beta_multiplier, and therefore as large as
@@ -4659,10 +4784,24 @@
   }
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
+                    int32 input_beta_multiplier, int32 input_beta_left_shift,
+                    int diff_min, uint8* output_data,
+                    const RuntimeShape& output_shape) {
+  SoftmaxParams params;
+  params.input_multiplier = input_beta_multiplier;
+  params.input_left_shift = input_beta_left_shift;
+  params.diff_min = diff_min;
+  Softmax(params, input_shape, input_data, output_shape, output_data);
+}
+
 // TODO(myenik): This is the same as the reference implementation, not actually
 // optimized yet.
-inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
-                       float* output_data, const RuntimeShape& output_shape) {
+inline void LogSoftmax(const SoftmaxParams& params,
+                       const RuntimeShape& input_shape, const float* input_data,
+                       const RuntimeShape& output_shape, float* output_data) {
   gemmlowp::ScopedProfilingLabel label("LogSoftmax");
   const int trailing_dim = input_shape.DimensionsCount() - 1;
   const int outer_size =
@@ -4695,6 +4834,15 @@
   }
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
+                       float* output_data, const RuntimeShape& output_shape) {
+  SoftmaxParams params;
+  // No params currently used for float LogSoftmax.
+  LogSoftmax(params, input_shape, input_data, output_shape, output_data);
+}
+
 template <int OutputIntegerBits, int InputIntegerBits>
 inline gemmlowp::FixedPoint<int32, OutputIntegerBits>
 log_x_for_x_greater_than_or_equal_to_1_impl(
@@ -4809,12 +4957,15 @@
 }
 
 // Currently just a copy of the reference code.
-inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
-                       int32 input_multiplier, int32 input_left_shift,
-                       int32 reverse_scaling_divisor,
-                       int32 reverse_scaling_right_shift, int diff_min,
-                       uint8* output_data, const RuntimeShape& output_shape) {
+inline void LogSoftmax(const SoftmaxParams& params,
+                       const RuntimeShape& input_shape, const uint8* input_data,
+                       const RuntimeShape& output_shape, uint8* output_data) {
   gemmlowp::ScopedProfilingLabel label("LogSoftmax/Uint8");
+  const int32 input_multiplier = params.input_multiplier;
+  const int32 input_left_shift = params.input_left_shift;
+  const int32 reverse_scaling_divisor = params.reverse_scaling_divisor;
+  const int32 reverse_scaling_right_shift = params.reverse_scaling_right_shift;
+  const int diff_min = params.diff_min;
   // The representation chosen for the input to the exp() function is Q5.26.
   // We need to leave extra space since values that we skip might be as large as
   // -32 before multiplying by input_beta_multiplier, and therefore as large as
@@ -4872,7 +5023,7 @@
         std::max(diff_min - 1,  // Note use of > below instead of >= above.
                  MultiplyByQuantizedMultiplierSmallerThanOneExp(
                      rescaled_diff_min, reverse_scaling_divisor,
-                     kReverseShift * reverse_scaling_right_shift));
+                     -reverse_scaling_right_shift));
 
     for (int c = 0; c < depth; ++c) {
       int32 input_diff = static_cast<int32>(block_input_data[c]) - max_in_row;
@@ -4896,6 +5047,22 @@
   }
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
+                       int32 input_multiplier, int32 input_left_shift,
+                       int32 reverse_scaling_divisor,
+                       int32 reverse_scaling_right_shift, int diff_min,
+                       uint8* output_data, const RuntimeShape& output_shape) {
+  SoftmaxParams params;
+  params.input_multiplier = input_multiplier;
+  params.input_left_shift = input_left_shift;
+  params.reverse_scaling_divisor = reverse_scaling_divisor;
+  params.reverse_scaling_right_shift = reverse_scaling_right_shift;
+  params.diff_min = diff_min;
+  LogSoftmax(params, input_shape, input_data, output_shape, output_data);
+}
+
 inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
                      const RuntimeShape& output_shape, float* output_data) {
   gemmlowp::ScopedProfilingLabel label("Logistic");
@@ -4905,11 +5072,23 @@
       input_map.array().unaryExpr(Eigen::internal::scalar_sigmoid_op<float>());
 }
 
-inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
-                     int32 input_zero_point, int32 input_range_radius,
-                     int32 input_multiplier, int input_left_shift,
-                     uint8* output_data, const RuntimeShape& output_shape) {
+// Convenience version that allows, for example, generated-code calls to be
+// uniform between data types.
+inline void Logistic(const LogisticParams&, const RuntimeShape& input_shape,
+                     const float* input_data, const RuntimeShape& output_shape,
+                     float* output_data) {
+  // Drop params: not needed.
+  Logistic(input_shape, input_data, output_shape, output_data);
+}
+
+inline void Logistic(const LogisticParams& params,
+                     const RuntimeShape& input_shape, const uint8* input_data,
+                     const RuntimeShape& output_shape, uint8* output_data) {
   gemmlowp::ScopedProfilingLabel label("Logistic/Uint8");
+  const int32 input_zero_point = params.input_zero_point;
+  const int32 input_range_radius = params.input_range_radius;
+  const int32 input_multiplier = params.input_multiplier;
+  const int input_left_shift = params.input_left_shift;
   const int size = MatchingFlatSize(input_shape, output_shape);
 
   int c = 0;
@@ -5042,7 +5221,22 @@
   }
 }
 
-inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
+                     int32 input_zero_point, int32 input_range_radius,
+                     int32 input_multiplier, int input_left_shift,
+                     uint8* output_data, const RuntimeShape& output_shape) {
+  LogisticParams params;
+  params.input_zero_point = input_zero_point;
+  params.input_range_radius = input_range_radius;
+  params.input_multiplier = input_multiplier;
+  params.input_left_shift = input_left_shift;
+  Logistic(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Logistic(const LogisticParams& params,
+                     const RuntimeShape& input_shape, const int16* input_data,
                      const RuntimeShape& output_shape, int16* output_data) {
   gemmlowp::ScopedProfilingLabel label("Logistic/Int16");
   const int flat_size = MatchingFlatSize(input_shape, output_shape);
@@ -5102,10 +5296,22 @@
   }
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy version.
+inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
+                     const RuntimeShape& output_shape, int16* output_data) {
+  LogisticParams params;
+  // No params currently needed by int16 Logistic.
+  Logistic(params, input_shape, input_data, output_shape, output_data);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
 // Legacy version.
 inline void Logistic(const int16* input_data, const RuntimeShape& input_shape,
                      int16* output_data, const RuntimeShape& output_shape) {
-  Logistic(input_shape, input_data, output_shape, output_data);
+  LogisticParams params;
+  // No params currently needed by int16 Logistic.
+  Logistic(params, input_shape, input_data, output_shape, output_data);
 }
 
 inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
@@ -5116,12 +5322,24 @@
   output_map.array() = input_map.array().tanh();
 }
 
-inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
-                 int32 input_zero_point, int32 input_range_radius,
-                 int32 input_multiplier, int input_left_shift,
-                 uint8* output_data, const RuntimeShape& output_shape) {
+// Convenience version that allows, for example, generated-code calls to be
+// uniform between data types.
+inline void Tanh(const TanhParams&, const RuntimeShape& input_shape,
+                 const float* input_data, const RuntimeShape& output_shape,
+                 float* output_data) {
+  // Drop params: not needed.
+  Tanh(input_shape, input_data, output_shape, output_data);
+}
+
+inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
+                 const uint8* input_data, const RuntimeShape& output_shape,
+                 uint8* output_data) {
   // Note that this is almost the exact same code as in Logistic().
   gemmlowp::ScopedProfilingLabel label("Tanh");
+  const int32 input_zero_point = params.input_zero_point;
+  const int32 input_range_radius = params.input_range_radius;
+  const int32 input_multiplier = params.input_multiplier;
+  const int input_left_shift = params.input_left_shift;
   const int size = MatchingFlatSize(input_shape, output_shape);
 
   int c = 0;
@@ -5263,10 +5481,25 @@
   }
 }
 
-inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
-                 int input_left_shift, int16* output_data,
-                 const RuntimeShape& output_shape) {
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
+                 int32 input_zero_point, int32 input_range_radius,
+                 int32 input_multiplier, int input_left_shift,
+                 uint8* output_data, const RuntimeShape& output_shape) {
+  TanhParams params;
+  params.input_zero_point = input_zero_point;
+  params.input_range_radius = input_range_radius;
+  params.input_multiplier = input_multiplier;
+  params.input_left_shift = input_left_shift;
+  Tanh(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
+                 const int16* input_data, const RuntimeShape& output_shape,
+                 int16* output_data) {
   gemmlowp::ScopedProfilingLabel label("Tanh/Int16");
+  const int input_left_shift = params.input_left_shift;
   // Support for shifts is limited until we have a parameterized version of
   // SaturatingRoundingMultiplyByPOT().
   TFLITE_DCHECK_GE(input_left_shift, 0);
@@ -5363,6 +5596,16 @@
   }
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
+                 int input_left_shift, int16* output_data,
+                 const RuntimeShape& output_shape) {
+  TanhParams params;
+  params.input_left_shift = input_left_shift;
+  Tanh(params, input_shape, input_data, output_shape, output_data);
+}
+
 template <typename SrcT, typename DstT>
 inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data,
                  const RuntimeShape& output_shape, DstT* output_data) {
@@ -6140,6 +6383,16 @@
   output_map.array() = input1_map.array().min(min_value);
 }
 
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T>
+inline void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
+                    const RuntimeShape&, const T* input2_data,
+                    const RuntimeShape& output_shape, T* output_data) {
+  // Drop shape of second input: not needed.
+  Minimum(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
 template <typename T>
 void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
              const T* input2_data, const RuntimeShape& output_shape,
@@ -6151,6 +6404,16 @@
   output_map.array() = input1_map.array().max(max_value);
 }
 
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T>
+inline void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
+                    const RuntimeShape&, const T* input2_data,
+                    const RuntimeShape& output_shape, T* output_data) {
+  // Drop shape of second input: not needed.
+  Maximum(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
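A small sketch of how these convenience overloads are called (values illustrative): the second operand is treated as a single scalar, and its shape argument is accepted only so the call site looks like the other binary ops.

RuntimeShape shape({1, 4});
float in[4] = {0.2f, 0.6f, -1.0f, 3.0f};
float limit = 0.5f;
float out[4];
Minimum(shape, in, shape, &limit, shape, out);
// out is {0.2f, 0.5f, -1.0f, 0.5f}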
 template <typename T>
 void TransposeIm2col(const ConvParams& params, uint8 zero_byte,
                      const RuntimeShape& input_shape, const T* input_data,
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h
index bb5d590..a842852 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h
@@ -22,25 +22,36 @@
 namespace tflite {
 namespace reference_ops {
 
-inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
-                          const float* filter_data, const Dims<4>& filter_dims,
-                          const float* bias_data, const Dims<4>& bias_dims,
-                          int stride_width, int stride_height,
-                          int dilation_width_factor, int dilation_height_factor,
-                          int pad_width, int pad_height, int depth_multiplier,
-                          float output_activation_min,
-                          float output_activation_max, float* output_data,
-                          const Dims<4>& output_dims) {
-  const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
-  const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
-  const int input_height = ArraySize(input_dims, 2);
-  const int input_width = ArraySize(input_dims, 1);
-  const int input_depth = ArraySize(input_dims, 0);
-  const int filter_height = ArraySize(filter_dims, 2);
-  const int filter_width = ArraySize(filter_dims, 1);
-  const int output_height = ArraySize(output_dims, 2);
-  const int output_width = ArraySize(output_dims, 1);
-  TFLITE_DCHECK(output_depth == input_depth * depth_multiplier);
+inline void DepthwiseConv(
+    const DepthwiseParams& params, const RuntimeShape& input_shape,
+    const float* input_data, const RuntimeShape& filter_shape,
+    const float* filter_data, const RuntimeShape& bias_shape,
+    const float* bias_data, const RuntimeShape& output_shape,
+    float* output_data) {
+  const int stride_width = params.stride_width;
+  const int stride_height = params.stride_height;
+  const int dilation_width_factor = params.dilation_width_factor;
+  const int dilation_height_factor = params.dilation_height_factor;
+  const int pad_width = params.padding_values.width;
+  const int pad_height = params.padding_values.height;
+  const int depth_multiplier = params.depth_multiplier;
+  const float output_activation_min = params.float_activation_min;
+  const float output_activation_max = params.float_activation_max;
+  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
+  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+  const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
+  const int input_height = input_shape.Dims(1);
+  const int input_width = input_shape.Dims(2);
+  const int input_depth = input_shape.Dims(3);
+  const int filter_height = filter_shape.Dims(1);
+  const int filter_width = filter_shape.Dims(2);
+  const int output_height = output_shape.Dims(1);
+  const int output_width = output_shape.Dims(2);
+  TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
+  TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
 
   for (int b = 0; b < batches; ++b) {
     for (int out_y = 0; out_y < output_height; ++out_y) {
@@ -61,18 +72,18 @@
                 if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
                     (in_y < input_height)) {
                   float input_value =
-                      input_data[Offset(input_dims, ic, in_x, in_y, b)];
+                      input_data[Offset(input_shape, b, in_y, in_x, ic)];
                   float filter_value = filter_data[Offset(
-                      filter_dims, oc, filter_x, filter_y, 0)];
+                      filter_shape, 0, filter_y, filter_x, oc)];
                   total += (input_value * filter_value);
                 }
               }
             }
             float bias_value = 0.0f;
             if (bias_data) {
-              bias_value = bias_data[Offset(bias_dims, oc, 0, 0, 0)];
+              bias_value = bias_data[oc];
             }
-            output_data[Offset(output_dims, oc, out_x, out_y, b)] =
+            output_data[Offset(output_shape, b, out_y, out_x, oc)] =
                 ActivationFunctionWithMinMax(total + bias_value,
                                              output_activation_min,
                                              output_activation_max);
@@ -83,6 +94,37 @@
   }
 }
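A note on the index-order change visible in the hunk above (a sketch, assuming the NHWC layout these kernels use): with RuntimeShape, Offset takes indices outermost-first, i.e. batch, y, x, channel, where the old Dims<4> helper took them innermost-first.

RuntimeShape nhwc({1, 2, 3, 4});       // batch, height, width, depth
// Row-major flattening, outermost dimension first:
//   Offset(nhwc, b, y, x, c) == ((b * 2 + y) * 3 + x) * 4 + c
int flat = Offset(nhwc, 0, 1, 2, 3);   // ((0 * 2 + 1) * 3 + 2) * 4 + 3 == 23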
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+                          const float* filter_data, const Dims<4>& filter_dims,
+                          const float* bias_data, const Dims<4>& bias_dims,
+                          int stride_width, int stride_height,
+                          int dilation_width_factor, int dilation_height_factor,
+                          int pad_width, int pad_height, int depth_multiplier,
+                          float output_activation_min,
+                          float output_activation_max, float* output_data,
+                          const Dims<4>& output_dims) {
+  tflite::DepthwiseParams op_params;
+  // Padding type is ignored, but still set.
+  op_params.padding_type = PaddingType::kSame;
+  op_params.padding_values.width = pad_width;
+  op_params.padding_values.height = pad_height;
+  op_params.stride_width = stride_width;
+  op_params.stride_height = stride_height;
+  op_params.dilation_width_factor = dilation_width_factor;
+  op_params.dilation_height_factor = dilation_height_factor;
+  op_params.depth_multiplier = depth_multiplier;
+  op_params.float_activation_min = output_activation_min;
+  op_params.float_activation_max = output_activation_max;
+
+  DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
+                DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+                bias_data, DimsToShape(output_dims), output_data);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
 inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
                           const float* filter_data, const Dims<4>& filter_dims,
                           const float* bias_data, const Dims<4>& bias_dims,
@@ -97,6 +139,7 @@
                 output_activation_max, output_data, output_dims);
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
 // Legacy, for compatibility with old checked-in code.
 template <FusedActivationFunctionType Ac>
 void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
@@ -113,6 +156,7 @@
                 output_data, output_dims);
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
 // Legacy, for compatibility with old checked-in code.
 template <FusedActivationFunctionType Ac>
 void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
index 5e3e899..ecc655c 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
@@ -26,27 +26,46 @@
 namespace tflite {
 namespace reference_ops {
 
-inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
-                          int32 input_offset, const uint8* filter_data,
-                          const Dims<4>& filter_dims, int32 filter_offset,
-                          const int32* bias_data, const Dims<4>& bias_dims,
-                          int stride_width, int stride_height,
-                          int dilation_width_factor, int dilation_height_factor,
-                          int pad_width, int pad_height, int depth_multiplier,
-                          int32 output_offset, int32 output_multiplier,
-                          int output_shift, int32 output_activation_min,
-                          int32 output_activation_max, uint8* output_data,
-                          const Dims<4>& output_dims) {
-  const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
-  const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
-  const int input_height = ArraySize(input_dims, 2);
-  const int input_width = ArraySize(input_dims, 1);
-  const int input_depth = ArraySize(input_dims, 0);
-  const int filter_height = ArraySize(filter_dims, 2);
-  const int filter_width = ArraySize(filter_dims, 1);
-  const int output_height = ArraySize(output_dims, 2);
-  const int output_width = ArraySize(output_dims, 1);
-  TFLITE_DCHECK(output_depth == input_depth * depth_multiplier);
+// TODO(b/80418076): Move to legacy ops file, along with invocations.
+static constexpr int kDepthwiseReverseShift = -1;
+
+inline void DepthwiseConv(
+    const DepthwiseParams& params, const RuntimeShape& input_shape,
+    const uint8* input_data, const RuntimeShape& filter_shape,
+    const uint8* filter_data, const RuntimeShape& bias_shape,
+    const int32* bias_data, const RuntimeShape& output_shape,
+    uint8* output_data) {
+  gemmlowp::ScopedProfilingLabel label("DepthwiseConv/8bit");
+  const int stride_width = params.stride_width;
+  const int stride_height = params.stride_height;
+  const int dilation_width_factor = params.dilation_width_factor;
+  const int dilation_height_factor = params.dilation_height_factor;
+  const int pad_width = params.padding_values.width;
+  const int pad_height = params.padding_values.height;
+  const int depth_multiplier = params.depth_multiplier;
+  const int32 output_activation_min = params.quantized_activation_min;
+  const int32 output_activation_max = params.quantized_activation_max;
+  const int32 input_offset = params.input_offset;
+  const int32 filter_offset = params.weights_offset;
+  const int32 output_offset = params.output_offset;
+  const int32 output_multiplier = params.output_multiplier;
+  const int output_shift = params.output_shift;
+  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
+  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+  const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
+  const int input_height = input_shape.Dims(1);
+  const int input_width = input_shape.Dims(2);
+  const int input_depth = input_shape.Dims(3);
+  const int filter_height = filter_shape.Dims(1);
+  const int filter_width = filter_shape.Dims(2);
+  const int output_height = output_shape.Dims(1);
+  const int output_width = output_shape.Dims(2);
+  TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
+  TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
 
   for (int b = 0; b < batches; ++b) {
     for (int out_y = 0; out_y < output_height; ++out_y) {
@@ -67,23 +86,23 @@
                 if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
                     (in_y < input_height)) {
                   int32 input_val =
-                      input_data[Offset(input_dims, ic, in_x, in_y, b)];
-                  int32 filter_val = filter_data[Offset(filter_dims, oc,
-                                                        filter_x, filter_y, 0)];
+                      input_data[Offset(input_shape, b, in_y, in_x, ic)];
+                  int32 filter_val = filter_data[Offset(
+                      filter_shape, 0, filter_y, filter_x, oc)];
                   acc +=
                       (filter_val + filter_offset) * (input_val + input_offset);
                 }
               }
             }
             if (bias_data) {
-              acc += bias_data[Offset(bias_dims, oc, 0, 0, 0)];
+              acc += bias_data[oc];
             }
             acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
-                                                -output_shift);
+                                                output_shift);
             acc += output_offset;
             acc = std::max(acc, output_activation_min);
             acc = std::min(acc, output_activation_max);
-            output_data[Offset(output_dims, oc, out_x, out_y, b)] =
+            output_data[Offset(output_shape, b, out_y, out_x, oc)] =
                 static_cast<uint8>(acc);
           }
         }
@@ -92,6 +111,44 @@
   }
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+                          int32 input_offset, const uint8* filter_data,
+                          const Dims<4>& filter_dims, int32 filter_offset,
+                          const int32* bias_data, const Dims<4>& bias_dims,
+                          int stride_width, int stride_height,
+                          int dilation_width_factor, int dilation_height_factor,
+                          int pad_width, int pad_height, int depth_multiplier,
+                          int32 output_offset, int32 output_multiplier,
+                          int output_shift, int32 output_activation_min,
+                          int32 output_activation_max, uint8* output_data,
+                          const Dims<4>& output_dims) {
+  tflite::DepthwiseParams op_params;
+  // Padding type is ignored, but still set.
+  op_params.padding_type = PaddingType::kSame;
+  op_params.padding_values.width = pad_width;
+  op_params.padding_values.height = pad_height;
+  op_params.stride_width = stride_width;
+  op_params.stride_height = stride_height;
+  op_params.dilation_width_factor = dilation_width_factor;
+  op_params.dilation_height_factor = dilation_height_factor;
+  op_params.depth_multiplier = depth_multiplier;
+  op_params.quantized_activation_min = output_activation_min;
+  op_params.quantized_activation_max = output_activation_max;
+  op_params.input_offset = input_offset;
+  op_params.weights_offset = filter_offset;
+  op_params.output_offset = output_offset;
+  op_params.output_multiplier = output_multiplier;
+  // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+  op_params.output_shift = kDepthwiseReverseShift * output_shift;
+
+  DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
+                DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+                bias_data, DimsToShape(output_dims), output_data);
+}
+
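A note on the shift-sign conversion in the wrapper above: the legacy entry points took output_shift in the old right-shift convention, so the wrapper negates it via kDepthwiseReverseShift (presumably -1, mirroring kReverseShift defined later in this diff) before calling the params-struct kernel. A minimal, illustrative-only sketch of what the new positive-means-left convention implies for MultiplyByQuantizedMultiplier -- the real routine uses saturating Q31 fixed-point arithmetic, not doubles, and this sketch ignores saturation:

    #include <cmath>
    #include <cstdint>

    // Illustrative approximation only. quantized_multiplier is a Q31 value
    // (real_multiplier * 2^31); a positive shift scales up by 2^shift (left
    // shift), a negative shift scales down (right shift).
    int32_t ApproxMultiplyByQuantizedMultiplier(int32_t x,
                                                int32_t quantized_multiplier,
                                                int shift) {
      const double real_multiplier =
          static_cast<double>(quantized_multiplier) / 2147483648.0;  // 2^31
      return static_cast<int32_t>(
          std::round(x * real_multiplier * std::ldexp(1.0, shift)));
    }

A legacy caller that passed output_shift == 3 (meaning "shift right by 3") therefore ends up passing shift == -3 under the new convention, and the overall scaling it gets is unchanged.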
+// TODO(b/80418076): Move to legacy ops file, update invocations.
 inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
                           int32 input_offset, const uint8* filter_data,
                           const Dims<4>& filter_dims, int32 filter_offset,
@@ -110,6 +167,7 @@
                 output_dims);
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
 // Legacy, for compatibility with old checked-in code.
 template <FusedActivationFunctionType Ac>
 void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
@@ -133,6 +191,7 @@
                 output_dims);
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
 // Legacy, for compatibility with old checked-in code.
 template <FusedActivationFunctionType Ac>
 void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
index 77e60ad..70d25c4 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
@@ -55,7 +55,7 @@
     return;
   }
   *scaling_factor = range / kScale;
-  const float scaling_factor_inv = 1.0f / *scaling_factor;
+  const float scaling_factor_inv = kScale / range;
   for (int i = 0; i < size; ++i) {
     const int32_t quantized_value =
         static_cast<int32_t>(TfLiteRound(values[i] * scaling_factor_inv));
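
The one-line change above computes the inverse scale directly as kScale / range instead of dividing by the already-rounded *scaling_factor, which removes one extra float rounding step. A minimal sketch of the surrounding symmetric-quantization logic, assuming kScale is 127.0f (int8) and that out-of-range values are clamped to [-128, 127]; the clamp bounds and helper name here are assumptions for illustration:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Sketch only: quantize floats symmetrically around zero, given
    // range == max(|values[i]|) over the buffer.
    void SymmetricQuantizeSketch(const float* values, int size, float range,
                                 int8_t* quantized, float* scaling_factor) {
      constexpr float kScale = 127.0f;  // assumed int8 symmetric scale
      *scaling_factor = range / kScale;
      const float scaling_factor_inv = kScale / range;  // matches the new code
      for (int i = 0; i < size; ++i) {
        const int32_t q =
            static_cast<int32_t>(std::round(values[i] * scaling_factor_inv));
        quantized[i] =
            static_cast<int8_t>(std::min<int32_t>(127, std::max<int32_t>(-128, q)));
      }
    }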
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 66f18ec..bb1d30b 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -105,11 +105,6 @@
 // Used mainly to convert from old-style shifts (right) to new-style (left).
 static constexpr int kReverseShift = -1;
 
-inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) {
-  return RuntimeShape(
-      {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
-}
-
 inline void ShapeFromDims(const tflite::Dims<4>& dims, RuntimeShape* shape) {
   shape->BuildFrom(
       {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
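
DimsToShape/ShapeFromDims reverse the dimension order because Dims<4>::sizes lists the innermost (channel) dimension first, while RuntimeShape for these kernels is stored outermost-first, i.e. NHWC. That reversal is why every Offset() call in this diff swaps its arguments, e.g. Offset(input_dims, c, x, y, b) becomes Offset(input_shape, b, y, x, c). A sketch of the flat index both forms resolve to for a packed NHWC tensor (FlatIndexNHWC is a hypothetical helper, not part of the library):

    // Hypothetical, for illustration: flat offset of element (b, y, x, c) in a
    // packed tensor of shape {N, H, W, C}.
    inline int FlatIndexNHWC(int H, int W, int C, int b, int y, int x, int c) {
      return ((b * H + y) * W + x) * C + c;
    }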
@@ -168,28 +163,38 @@
       SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent));
 }
 
-inline void Conv(const float* input_data, const Dims<4>& input_dims,
-                 const float* filter_data, const Dims<4>& filter_dims,
-                 const float* bias_data, const Dims<4>& bias_dims,
-                 int stride_width, int stride_height, int dilation_width_factor,
-                 int dilation_height_factor, int pad_width, int pad_height,
-                 float output_activation_min, float output_activation_max,
-                 float* output_data, const Dims<4>& output_dims,
-                 float* im2col_data, const Dims<4>& im2col_dims) {
+inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
+                 const float* input_data, const RuntimeShape& filter_shape,
+                 const float* filter_data, const RuntimeShape& bias_shape,
+                 const float* bias_data, const RuntimeShape& output_shape,
+                 float* output_data, const RuntimeShape& im2col_shape,
+                 float* im2col_data) {
+  const int stride_width = params.stride_width;
+  const int stride_height = params.stride_height;
+  const int dilation_width_factor = params.dilation_width_factor;
+  const int dilation_height_factor = params.dilation_height_factor;
+  const int pad_width = params.padding_values.width;
+  const int pad_height = params.padding_values.height;
+  const float output_activation_min = params.float_activation_min;
+  const float output_activation_max = params.float_activation_max;
+  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
   (void)im2col_data;  // only used in optimized code.
-  (void)im2col_dims;  // only used in optimized code.
-  const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
-  const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0);
-  const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0);
+  (void)im2col_shape;  // only used in optimized code.
+  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+  const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+  const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
   if (bias_data) {
-    TFLITE_DCHECK_EQ(ArraySize(filter_dims, 3), ArraySize(bias_dims, 0));
+    TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
   }
-  const int input_height = ArraySize(input_dims, 2);
-  const int input_width = ArraySize(input_dims, 1);
-  const int filter_height = ArraySize(filter_dims, 2);
-  const int filter_width = ArraySize(filter_dims, 1);
-  const int output_height = ArraySize(output_dims, 2);
-  const int output_width = ArraySize(output_dims, 1);
+  const int input_height = input_shape.Dims(1);
+  const int input_width = input_shape.Dims(2);
+  const int filter_height = filter_shape.Dims(1);
+  const int filter_width = filter_shape.Dims(2);
+  const int output_height = output_shape.Dims(1);
+  const int output_width = output_shape.Dims(2);
   for (int batch = 0; batch < batches; ++batch) {
     for (int out_y = 0; out_y < output_height; ++out_y) {
       for (int out_x = 0; out_x < output_width; ++out_x) {
@@ -207,11 +212,11 @@
                 // use zero as a default value.
                 if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
                     (in_y < input_height)) {
-                  float input_value = input_data[Offset(input_dims, in_channel,
-                                                        in_x, in_y, batch)];
+                  float input_value = input_data[Offset(
+                      input_shape, batch, in_y, in_x, in_channel)];
                   float filter_value =
-                      filter_data[Offset(filter_dims, in_channel, filter_x,
-                                         filter_y, out_channel)];
+                      filter_data[Offset(filter_shape, out_channel, filter_y,
+                                         filter_x, in_channel)];
                   total += (input_value * filter_value);
                 }
               }
@@ -219,9 +224,9 @@
           }
           float bias_value = 0.0f;
           if (bias_data) {
-            bias_value = bias_data[Offset(bias_dims, out_channel, 0, 0, 0)];
+            bias_value = bias_data[out_channel];
           }
-          output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] =
+          output_data[Offset(output_shape, batch, out_y, out_x, out_channel)] =
               ActivationFunctionWithMinMax(total + bias_value,
                                            output_activation_min,
                                            output_activation_max);
@@ -231,6 +236,35 @@
   }
 }
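
Spelled out, the reference float convolution above computes, for each output element (assuming the usual in_y_origin = out_y * stride_height - pad_height and in_x_origin = out_x * stride_width - pad_width, which the loop-setup lines elided from this hunk presumably compute):

    output(b, oy, ox, oc) = Act( bias(oc) +
        sum over (fy, fx, ic) of
            input(b, oy*stride_h - pad_h + fy*dilation_h,
                     ox*stride_w - pad_w + fx*dilation_w, ic)
            * filter(oc, fy, fx, ic) )

with out-of-bounds input coordinates contributing zero and Act clamping to [output_activation_min, output_activation_max].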
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Conv(const float* input_data, const Dims<4>& input_dims,
+                 const float* filter_data, const Dims<4>& filter_dims,
+                 const float* bias_data, const Dims<4>& bias_dims,
+                 int stride_width, int stride_height, int dilation_width_factor,
+                 int dilation_height_factor, int pad_width, int pad_height,
+                 float output_activation_min, float output_activation_max,
+                 float* output_data, const Dims<4>& output_dims,
+                 float* im2col_data, const Dims<4>& im2col_dims) {
+  tflite::ConvParams op_params;
+  // Padding type is ignored, but still set.
+  op_params.padding_type = PaddingType::kSame;
+  op_params.padding_values.width = pad_width;
+  op_params.padding_values.height = pad_height;
+  op_params.stride_width = stride_width;
+  op_params.stride_height = stride_height;
+  op_params.dilation_width_factor = dilation_width_factor;
+  op_params.dilation_height_factor = dilation_height_factor;
+  op_params.float_activation_min = output_activation_min;
+  op_params.float_activation_max = output_activation_max;
+
+  Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
+       filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+       output_data, DimsToShape(im2col_dims), im2col_data);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
 template <FusedActivationFunctionType Ac>
 void Conv(const float* input_data, const Dims<4>& input_dims,
           const float* filter_data, const Dims<4>& filter_dims,
@@ -248,6 +282,7 @@
        im2col_dims);
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
 // legacy, for compatibility with old checked-in code
 template <FusedActivationFunctionType Ac>
 void Conv(const float* input_data, const Dims<4>& input_dims,
@@ -264,6 +299,7 @@
        im2col_data, im2col_dims);
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
 // legacy, for compatibility with old checked-in code
 template <FusedActivationFunctionType Ac>
 void Conv(const float* input_data, const Dims<4>& input_dims,
@@ -277,31 +313,45 @@
            output_dims, im2col_data, im2col_dims);
 }
 
-inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
-                 int32 input_offset, const uint8* filter_data,
-                 const Dims<4>& filter_dims, int32 filter_offset,
-                 const int32* bias_data, const Dims<4>& bias_dims,
-                 int stride_width, int stride_height, int dilation_width_factor,
-                 int dilation_height_factor, int pad_width, int pad_height,
-                 int32 output_offset, int32 output_multiplier, int output_shift,
-                 int32 output_activation_min, int32 output_activation_max,
-                 uint8* output_data, const Dims<4>& output_dims,
-                 uint8* im2col_data, const Dims<4>& im2col_dims,
-                 gemmlowp::GemmContext* gemm_context) {
+inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
+                 const uint8* input_data, const RuntimeShape& filter_shape,
+                 const uint8* filter_data, const RuntimeShape& bias_shape,
+                 const int32* bias_data, const RuntimeShape& output_shape,
+                 uint8* output_data, const RuntimeShape& im2col_shape,
+                 uint8* im2col_data, gemmlowp::GemmContext* gemm_context) {
   (void)im2col_data;   // only used in optimized code.
-  (void)im2col_dims;   // only used in optimized code.
+  (void)im2col_shape;  // only used in optimized code.
   (void)gemm_context;  // only used in optimized code.
+  const int stride_width = params.stride_width;
+  const int stride_height = params.stride_height;
+  const int dilation_width_factor = params.dilation_width_factor;
+  const int dilation_height_factor = params.dilation_height_factor;
+  const int pad_width = params.padding_values.width;
+  const int pad_height = params.padding_values.height;
+  const int32 input_offset = params.input_offset;
+  const int32 filter_offset = params.weights_offset;
+  const int32 output_offset = params.output_offset;
+  const int32 output_multiplier = params.output_multiplier;
+  const int output_shift = params.output_shift;
+  const int32 output_activation_min = params.quantized_activation_min;
+  const int32 output_activation_max = params.quantized_activation_max;
   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
-  const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
-  const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0);
-  const int output_depth =
-      MatchingArraySize(filter_dims, 3, bias_dims, 0, output_dims, 0);
-  const int input_height = ArraySize(input_dims, 2);
-  const int input_width = ArraySize(input_dims, 1);
-  const int filter_height = ArraySize(filter_dims, 2);
-  const int filter_width = ArraySize(filter_dims, 1);
-  const int output_height = ArraySize(output_dims, 2);
-  const int output_width = ArraySize(output_dims, 1);
+
+  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+  const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+  const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
+  if (bias_data) {
+    TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
+  }
+  const int input_height = input_shape.Dims(1);
+  const int input_width = input_shape.Dims(2);
+  const int filter_height = filter_shape.Dims(1);
+  const int filter_width = filter_shape.Dims(2);
+  const int output_height = output_shape.Dims(1);
+  const int output_width = output_shape.Dims(2);
   for (int batch = 0; batch < batches; ++batch) {
     for (int out_y = 0; out_y < output_height; ++out_y) {
       for (int out_x = 0; out_x < output_width; ++out_x) {
@@ -319,11 +369,11 @@
                 // use zero as a default value.
                 if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
                     (in_y < input_height)) {
-                  int32 input_val = input_data[Offset(input_dims, in_channel,
-                                                      in_x, in_y, batch)];
+                  int32 input_val = input_data[Offset(input_shape, batch, in_y,
+                                                      in_x, in_channel)];
                   int32 filter_val =
-                      filter_data[Offset(filter_dims, in_channel, filter_x,
-                                         filter_y, out_channel)];
+                      filter_data[Offset(filter_shape, out_channel, filter_y,
+                                         filter_x, in_channel)];
                   acc +=
                       (filter_val + filter_offset) * (input_val + input_offset);
                 }
@@ -331,14 +381,14 @@
             }
           }
           if (bias_data) {
-            acc += bias_data[Offset(bias_dims, out_channel, 0, 0, 0)];
+            acc += bias_data[out_channel];
           }
           acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
-                                              kReverseShift * output_shift);
+                                              output_shift);
           acc += output_offset;
           acc = std::max(acc, output_activation_min);
           acc = std::min(acc, output_activation_max);
-          output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] =
+          output_data[Offset(output_shape, batch, out_y, out_x, out_channel)] =
               static_cast<uint8>(acc);
         }
       }
@@ -346,6 +396,44 @@
   }
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
+                 int32 input_offset, const uint8* filter_data,
+                 const Dims<4>& filter_dims, int32 filter_offset,
+                 const int32* bias_data, const Dims<4>& bias_dims,
+                 int stride_width, int stride_height, int dilation_width_factor,
+                 int dilation_height_factor, int pad_width, int pad_height,
+                 int32 output_offset, int32 output_multiplier, int output_shift,
+                 int32 output_activation_min, int32 output_activation_max,
+                 uint8* output_data, const Dims<4>& output_dims,
+                 uint8* im2col_data, const Dims<4>& im2col_dims,
+                 gemmlowp::GemmContext* gemm_context) {
+  tflite::ConvParams op_params;
+  // Padding type is ignored, but still set.
+  op_params.padding_type = PaddingType::kSame;
+  op_params.padding_values.width = pad_width;
+  op_params.padding_values.height = pad_height;
+  op_params.stride_width = stride_width;
+  op_params.stride_height = stride_height;
+  op_params.dilation_width_factor = dilation_width_factor;
+  op_params.dilation_height_factor = dilation_height_factor;
+  op_params.input_offset = input_offset;
+  op_params.weights_offset = filter_offset;
+  op_params.output_offset = output_offset;
+  op_params.output_multiplier = output_multiplier;
+  // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+  op_params.output_shift = kReverseShift * output_shift;
+  op_params.quantized_activation_min = output_activation_min;
+  op_params.quantized_activation_max = output_activation_max;
+
+  Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
+       filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+       output_data, DimsToShape(im2col_dims), im2col_data, gemm_context);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
 inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
                  int32 input_offset, const uint8* filter_data,
                  const Dims<4>& filter_dims, int32 filter_offset,
@@ -364,6 +452,7 @@
        im2col_data, im2col_dims, gemm_context);
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
 // legacy, for compatibility with old checked-in code
 template <FusedActivationFunctionType Ac>
 inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
@@ -393,6 +482,7 @@
        im2col_data, im2col_dims, gemm_context);
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
 // legacy, for compatibility with old checked-in code
 template <FusedActivationFunctionType Ac>
 void Conv(const uint8* input_data, const Dims<4>& input_dims,
@@ -516,24 +606,25 @@
   }
 }
 
-inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
-                           const float* weights_data,
-                           const Dims<4>& weights_dims, const float* bias_data,
-                           const Dims<4>& bias_dims,
-                           float output_activation_min,
-                           float output_activation_max, float* output_data,
-                           const Dims<4>& output_dims) {
+inline void FullyConnected(
+    const FullyConnectedParams& params, const RuntimeShape& input_shape,
+    const float* input_data, const RuntimeShape& weights_shape,
+    const float* weights_data, const RuntimeShape& bias_shape,
+    const float* bias_data, const RuntimeShape& output_shape,
+    float* output_data) {
+  const float output_activation_min = params.float_activation_min;
+  const float output_activation_max = params.float_activation_max;
   // TODO(benoitjacob): This really should be:
   //     const int batches = ArraySize(output_dims, 1);
   // but the current --variable_batch hack consists in overwriting the 3rd
   // dimension with the runtime batch size, as we don't keep track for each
   // array of which dimension is the batch dimension in it.
-  const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
-                      ArraySize(output_dims, 3);
-  const int output_depth = MatchingArraySize(weights_dims, 1, output_dims, 0);
-  const int accum_depth = ArraySize(weights_dims, 0);
-  TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
-  TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims));
+  const int output_dims_count = output_shape.DimensionsCount();
+  const int weights_dims_count = weights_shape.DimensionsCount();
+  const int batches = FlatSizeSkipDim(output_shape, output_dims_count - 1);
+  const int output_depth = MatchingDim(weights_shape, weights_dims_count - 2,
+                                       output_shape, output_dims_count - 1);
+  const int accum_depth = weights_shape.Dims(weights_dims_count - 1);
   for (int b = 0; b < batches; ++b) {
     for (int out_c = 0; out_c < output_depth; ++out_c) {
       float total = 0.f;
@@ -543,7 +634,7 @@
       }
       float bias_value = 0.0f;
       if (bias_data) {
-        bias_value = bias_data[Offset(bias_dims, out_c, 0, 0, 0)];
+        bias_value = bias_data[out_c];
       }
       output_data[out_c + output_depth * b] = ActivationFunctionWithMinMax(
           total + bias_value, output_activation_min, output_activation_max);
@@ -551,6 +642,26 @@
   }
 }
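
A minimal usage sketch of the new params-style float FullyConnected, assuming RuntimeShape accepts an initializer list (as BuildFrom above suggests) and that the function lives in tflite::reference_ops; the shapes are hypothetical: 2 batches of 64 features producing 10 outputs each.

    #include <limits>

    #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"

    // Sketch only. output_depth comes from weights dim -2 (10) matching the
    // last output dim; accum_depth is weights dim -1 (64); batches is the flat
    // size of the output with its last dim skipped (2).
    void RunFullyConnectedSketch(const float* input, const float* weights,
                                 const float* bias, float* output) {
      tflite::FullyConnectedParams params;
      params.float_activation_min = std::numeric_limits<float>::lowest();
      params.float_activation_max = std::numeric_limits<float>::max();
      tflite::reference_ops::FullyConnected(
          params, tflite::RuntimeShape({2, 64}), input,
          tflite::RuntimeShape({10, 64}), weights, tflite::RuntimeShape({10}),
          bias, tflite::RuntimeShape({2, 10}), output);
    }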
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
+                           const float* weights_data,
+                           const Dims<4>& weights_dims, const float* bias_data,
+                           const Dims<4>& bias_dims,
+                           float output_activation_min,
+                           float output_activation_max, float* output_data,
+                           const Dims<4>& output_dims) {
+  tflite::FullyConnectedParams op_params;
+  op_params.float_activation_min = output_activation_min;
+  op_params.float_activation_max = output_activation_max;
+
+  FullyConnected(op_params, DimsToShape(input_dims), input_data,
+                 DimsToShape(weights_dims), weights_data,
+                 DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+                 output_data);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
 // legacy, for compatibility with old checked-in code
 template <FusedActivationFunctionType Ac>
 void FullyConnected(const float* input_data, const Dims<4>& input_dims,
@@ -564,28 +675,35 @@
                  output_data, output_dims);
 }
 
-inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
-                           int32 input_offset, const uint8* filter_data,
-                           const Dims<4>& filter_dims, int32 filter_offset,
-                           const int32* bias_data, const Dims<4>& bias_dims,
-                           int32 output_offset, int32 output_multiplier,
-                           int output_shift, int32 output_activation_min,
-                           int32 output_activation_max, uint8* output_data,
-                           const Dims<4>& output_dims,
-                           gemmlowp::GemmContext* gemm_context) {
+inline void FullyConnected(
+    const FullyConnectedParams& params, const RuntimeShape& input_shape,
+    const uint8* input_data, const RuntimeShape& filter_shape,
+    const uint8* filter_data, const RuntimeShape& bias_shape,
+    const int32* bias_data, const RuntimeShape& output_shape,
+    uint8* output_data, gemmlowp::GemmContext* gemm_context) {
   (void)gemm_context;  // only used in optimized code.
+  const int32 input_offset = params.input_offset;
+  const int32 filter_offset = params.weights_offset;
+  const int32 output_offset = params.output_offset;
+  const int32 output_multiplier = params.output_multiplier;
+  const int output_shift = params.output_shift;
+  const int32 output_activation_min = params.quantized_activation_min;
+  const int32 output_activation_max = params.quantized_activation_max;
+  TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
+  TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
+
   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
   // TODO(benoitjacob): This really should be:
   //     const int batches = ArraySize(output_dims, 1);
   // but the current --variable_batch hack consists in overwriting the 3rd
   // dimension with the runtime batch size, as we don't keep track for each
   // array of which dimension is the batch dimension in it.
-  const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
-                      ArraySize(output_dims, 3);
-  const int output_depth = MatchingArraySize(filter_dims, 1, output_dims, 0);
-  const int accum_depth = ArraySize(filter_dims, 0);
-  TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
-  TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
+  const int output_dim_count = output_shape.DimensionsCount();
+  const int filter_dim_count = filter_shape.DimensionsCount();
+  const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
+  const int output_depth = MatchingDim(filter_shape, filter_dim_count - 2,
+                                       output_shape, output_dim_count - 1);
+  const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
   for (int b = 0; b < batches; ++b) {
     for (int out_c = 0; out_c < output_depth; ++out_c) {
       int32 acc = 0;
@@ -595,10 +713,9 @@
         acc += (filter_val + filter_offset) * (input_val + input_offset);
       }
       if (bias_data) {
-        acc += bias_data[Offset(bias_dims, out_c, 0, 0, 0)];
+        acc += bias_data[out_c];
       }
-      acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
-                                          kReverseShift * output_shift);
+      acc = MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift);
       acc += output_offset;
       acc = std::max(acc, output_activation_min);
       acc = std::min(acc, output_activation_max);
@@ -607,16 +724,48 @@
   }
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
 inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
                            int32 input_offset, const uint8* filter_data,
                            const Dims<4>& filter_dims, int32 filter_offset,
                            const int32* bias_data, const Dims<4>& bias_dims,
                            int32 output_offset, int32 output_multiplier,
                            int output_shift, int32 output_activation_min,
-                           int32 output_activation_max, int16* output_data,
+                           int32 output_activation_max, uint8* output_data,
                            const Dims<4>& output_dims,
                            gemmlowp::GemmContext* gemm_context) {
+  tflite::FullyConnectedParams op_params;
+  op_params.input_offset = input_offset;
+  op_params.weights_offset = filter_offset;
+  op_params.output_offset = output_offset;
+  op_params.output_multiplier = output_multiplier;
+  // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+  op_params.output_shift = kReverseShift * output_shift;
+  op_params.quantized_activation_min = output_activation_min;
+  op_params.quantized_activation_max = output_activation_max;
+
+  FullyConnected(op_params, DimsToShape(input_dims), input_data,
+                 DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+                 bias_data, DimsToShape(output_dims), output_data,
+                 gemm_context);
+}
+
+inline void FullyConnected(
+    const FullyConnectedParams& params, const RuntimeShape& input_shape,
+    const uint8* input_data, const RuntimeShape& filter_shape,
+    const uint8* filter_data, const RuntimeShape& bias_shape,
+    const int32* bias_data, const RuntimeShape& output_shape,
+    int16* output_data, gemmlowp::GemmContext* gemm_context) {
   (void)gemm_context;  // only used in optimized code.
+  const int32 input_offset = params.input_offset;
+  const int32 filter_offset = params.weights_offset;
+  const int32 output_offset = params.output_offset;
+  const int32 output_multiplier = params.output_multiplier;
+  const int output_shift = params.output_shift;
+  const int32 output_activation_min = params.quantized_activation_min;
+  const int32 output_activation_max = params.quantized_activation_max;
+
   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
   TFLITE_DCHECK_EQ(output_offset, 0);
   // TODO(benoitjacob): This really should be:
@@ -624,12 +773,12 @@
   // but the current --variable_batch hack consists in overwriting the 3rd
   // dimension with the runtime batch size, as we don't keep track for each
   // array of which dimension is the batch dimension in it.
-  const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
-                      ArraySize(output_dims, 3);
-  const int output_depth = MatchingArraySize(filter_dims, 1, output_dims, 0);
-  const int accum_depth = ArraySize(filter_dims, 0);
-  TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
-  TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
+  const int output_dim_count = output_shape.DimensionsCount();
+  const int filter_dim_count = filter_shape.DimensionsCount();
+  const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
+  const int output_depth = MatchingDim(filter_shape, filter_dim_count - 2,
+                                       output_shape, output_dim_count - 1);
+  const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
   for (int b = 0; b < batches; ++b) {
     for (int out_c = 0; out_c < output_depth; ++out_c) {
       // Internal accumulation.
@@ -645,8 +794,8 @@
       // (16-bit, typically 3 integer bits) fixed-point format. The quantized
       // multiplier and shift here have been pre-computed offline
       // (e.g. by toco).
-      accum = MultiplyByQuantizedMultiplier(accum, output_multiplier,
-                                            -output_shift);
+      accum =
+          MultiplyByQuantizedMultiplier(accum, output_multiplier, output_shift);
       // Saturate, cast to int16, and store to output array.
       accum = std::max(accum, output_activation_min - output_offset);
       accum = std::min(accum, output_activation_max - output_offset);
@@ -656,27 +805,61 @@
   }
 }
 
-inline void ShuffledFullyConnected(
-    const uint8* input_data, const Dims<4>& input_dims,
-    const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
-    const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
-    int output_shift, int32 output_activation_min, int32 output_activation_max,
-    int16* output_data, const Dims<4>& output_dims,
-    uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) {
-  (void)gemm_context;  // only used in optimized code.
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
+                           int32 input_offset, const uint8* filter_data,
+                           const Dims<4>& filter_dims, int32 filter_offset,
+                           const int32* bias_data, const Dims<4>& bias_dims,
+                           int32 output_offset, int32 output_multiplier,
+                           int output_shift, int32 output_activation_min,
+                           int32 output_activation_max, int16* output_data,
+                           const Dims<4>& output_dims,
+                           gemmlowp::GemmContext* gemm_context) {
+  tflite::FullyConnectedParams op_params;
+  op_params.input_offset = input_offset;
+  op_params.weights_offset = filter_offset;
+  op_params.output_offset = output_offset;
+  op_params.output_multiplier = output_multiplier;
+  // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+  op_params.output_shift = kReverseShift * output_shift;
+  op_params.quantized_activation_min = output_activation_min;
+  op_params.quantized_activation_max = output_activation_max;
 
+  FullyConnected(op_params, DimsToShape(input_dims), input_data,
+                 DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+                 bias_data, DimsToShape(output_dims), output_data,
+                 gemm_context);
+}
+
+inline void ShuffledFullyConnected(
+    const FullyConnectedParams& params, const RuntimeShape& input_shape,
+    const uint8* input_data, const RuntimeShape& weights_shape,
+    const uint8* shuffled_weights_data, const RuntimeShape& bias_shape,
+    const int32* bias_data, const RuntimeShape& output_shape,
+    int16* output_data, uint8* shuffled_input_workspace_data,
+    gemmlowp::GemmContext* gemm_context) {
+  (void)gemm_context;  // only used in optimized code.
+  const int32 output_multiplier = params.output_multiplier;
+  const int output_shift = params.output_shift;
+  const int32 output_activation_min = params.quantized_activation_min;
+  const int32 output_activation_max = params.quantized_activation_max;
   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+
+  TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
+  TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
+  TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
   // TODO(benoitjacob): This really should be:
   //     const int batches = ArraySize(output_dims, 1);
   // but the current --variable_batch hack consists in overwriting the 3rd
   // dimension with the runtime batch size, as we don't keep track for each
   // array of which dimension is the batch dimension in it.
-  const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
-                      ArraySize(output_dims, 3);
-  const int output_depth = MatchingArraySize(weights_dims, 1, output_dims, 0);
-  const int accum_depth = ArraySize(weights_dims, 0);
-  TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
-  TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims));
+  const int output_dim_count = output_shape.DimensionsCount();
+  const int weights_dim_count = weights_shape.DimensionsCount();
+  const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
+  const int output_depth = MatchingDim(weights_shape, weights_dim_count - 2,
+                                       output_shape, output_dim_count - 1);
+  const int accum_depth = weights_shape.Dims(weights_dim_count - 1);
   TFLITE_DCHECK((accum_depth % 16) == 0);
   TFLITE_DCHECK((output_depth % 4) == 0);
 
@@ -738,8 +921,8 @@
         // (16-bit, typically 3 integer bits) fixed-point format. The quantized
         // multiplier and shift here have been pre-computed offline
         // (e.g. by toco).
-        acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
-                                            -output_shift);
+        acc =
+            MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift);
         // Saturate, cast to int16, and store to output array.
         acc = std::max(acc, output_activation_min);
         acc = std::min(acc, output_activation_max);
@@ -790,7 +973,7 @@
           // quantized multiplier and shift here have been pre-computed offline
           // (e.g. by toco).
           acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
-                                              -output_shift);
+                                              output_shift);
           // Saturate, cast to int16, and store to output array.
           acc = std::max(acc, output_activation_min);
           acc = std::min(acc, output_activation_max);
@@ -804,6 +987,30 @@
   }
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void ShuffledFullyConnected(
+    const uint8* input_data, const Dims<4>& input_dims,
+    const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
+    const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
+    int output_shift, int32 output_activation_min, int32 output_activation_max,
+    int16* output_data, const Dims<4>& output_dims,
+    uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) {
+  tflite::FullyConnectedParams op_params;
+  op_params.output_multiplier = output_multiplier;
+  // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+  op_params.output_shift = kReverseShift * output_shift;
+  op_params.quantized_activation_min = output_activation_min;
+  op_params.quantized_activation_max = output_activation_max;
+
+  ShuffledFullyConnected(op_params, DimsToShape(input_dims), input_data,
+                         DimsToShape(weights_dims), shuffled_weights_data,
+                         DimsToShape(bias_dims), bias_data,
+                         DimsToShape(output_dims), output_data,
+                         shuffled_input_workspace_data, gemm_context);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
 // legacy, for compatibility with old checked-in code
 template <FusedActivationFunctionType Ac>
 void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
@@ -950,6 +1157,7 @@
     *output_inv_sqrt <<= -*output_shift;
     *output_shift = 0;
   }
+  // Convert the right-shift convention (positive means right) to left-shift.
   *output_shift *= kReverseShift;
 }
 
@@ -1708,7 +1916,7 @@
                                const float* input2_data,
                                const RuntimeShape& output_shape,
                                float* output_data) {
-  gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/float");
+  gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/float");
   NdArrayDesc<4> desc1;
   NdArrayDesc<4> desc2;
   NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
@@ -1749,7 +1957,7 @@
                                const uint8* input2_data,
                                const RuntimeShape& output_shape,
                                uint8* output_data) {
-  gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/uint8");
+  gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/uint8");
   NdArrayDesc<4> desc1;
   NdArrayDesc<4> desc2;
   NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
@@ -1813,7 +2021,7 @@
                                const int32* input2_data,
                                const RuntimeShape& output_shape,
                                int32* output_data) {
-  gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/int32");
+  gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/int32");
   NdArrayDesc<4> desc1;
   NdArrayDesc<4> desc2;
   NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
@@ -1853,7 +2061,7 @@
                         const RuntimeShape& input1_shape, const T* input1_data,
                         const RuntimeShape& input2_shape, const T* input2_data,
                         const RuntimeShape& output_shape, T* output_data) {
-  gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/templated");
+  gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/templated");
   NdArrayDesc<4> desc1;
   NdArrayDesc<4> desc2;
   NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
@@ -2204,6 +2412,125 @@
                             output_data, output_dims);
 }
 
+inline void LstmCell(
+    const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
+    const float* input_data, const RuntimeShape& unextended_prev_activ_shape,
+    const float* prev_activ_data, const RuntimeShape& weights_shape,
+    const float* weights_data, const RuntimeShape& unextended_bias_shape,
+    const float* bias_data, const RuntimeShape& unextended_prev_state_shape,
+    const float* prev_state_data,
+    const RuntimeShape& unextended_output_state_shape, float* output_state_data,
+    const RuntimeShape& unextended_output_activ_shape, float* output_activ_data,
+    const RuntimeShape& unextended_concat_temp_shape, float* concat_temp_data,
+    const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data) {
+  TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
+  const RuntimeShape input_shape =
+      RuntimeShape::ExtendedShape(4, unextended_input_shape);
+  const RuntimeShape prev_activ_shape =
+      RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
+  const RuntimeShape bias_shape =
+      RuntimeShape::ExtendedShape(4, unextended_bias_shape);
+  const RuntimeShape prev_state_shape =
+      RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
+  const RuntimeShape output_state_shape =
+      RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
+  const RuntimeShape output_activ_shape =
+      RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
+  const RuntimeShape concat_temp_shape =
+      RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
+  const RuntimeShape activ_temp_shape =
+      RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
+  TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
+
+  const int weights_dim_count = weights_shape.DimensionsCount();
+  const int batches =
+      MatchingDim(input_shape, 0, prev_activ_shape, 0, prev_state_shape, 0,
+                  output_state_shape, 0, output_activ_shape, 0);
+  const int height =
+      MatchingDim(input_shape, 1, prev_activ_shape, 1, prev_state_shape, 1,
+                  output_state_shape, 1, output_activ_shape, 1);
+  const int width =
+      MatchingDim(input_shape, 2, prev_activ_shape, 2, prev_state_shape, 2,
+                  output_state_shape, 2, output_activ_shape, 2);
+  const int input_depth = input_shape.Dims(3);
+  const int prev_activ_depth = prev_activ_shape.Dims(3);
+  const int total_input_depth = prev_activ_depth + input_depth;
+  TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
+                   total_input_depth);
+  TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
+  const int intern_activ_depth =
+      MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
+  TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
+                   intern_activ_depth * total_input_depth);
+  TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
+  const int output_depth =
+      MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
+                  3, output_activ_shape, 3);
+  TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
+
+  // Concatenate prev_activ and input data together
+  std::vector<float const*> concat_input_arrays_data;
+  std::vector<RuntimeShape const*> concat_input_arrays_shapes;
+  concat_input_arrays_data.push_back(input_data);
+  concat_input_arrays_data.push_back(prev_activ_data);
+  concat_input_arrays_shapes.push_back(&input_shape);
+  concat_input_arrays_shapes.push_back(&prev_activ_shape);
+  tflite::ConcatenationParams concat_params;
+  concat_params.axis = 3;
+  concat_params.inputs_count = concat_input_arrays_data.size();
+  Concatenation(concat_params, &(concat_input_arrays_shapes[0]),
+                &(concat_input_arrays_data[0]), concat_temp_shape,
+                concat_temp_data);
+
+  // Fully connected
+  tflite::FullyConnectedParams fc_params;
+  fc_params.float_activation_min = std::numeric_limits<float>::lowest();
+  fc_params.float_activation_max = std::numeric_limits<float>::max();
+  FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape,
+                 weights_data, bias_shape, bias_data, activ_temp_shape,
+                 activ_temp_data);
+
+  // Memory state update (the LSTM "guts")
+  for (int b = 0; b < batches; ++b) {
+    for (int w = 0; w < width; ++w) {
+      for (int h = 0; h < height; ++h) {
+        for (int c = 0; c < output_depth; ++c) {
+          const float input_gate =
+              1.f /
+              (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
+                                                      0 * output_depth + c)]));
+          const float new_input = std::tanh(activ_temp_data[Offset(
+              activ_temp_shape, b, h, w, 1 * output_depth + c)]);
+          const float forget_gate =
+              1.f /
+              (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
+                                                      2 * output_depth + c)]));
+          const float output_gate =
+              1.f /
+              (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
+                                                      3 * output_depth + c)]));
+          const float new_state =
+              input_gate * new_input +
+              forget_gate *
+                  prev_state_data[Offset(prev_state_shape, b, h, w, c)];
+          output_state_data[Offset(output_state_shape, b, h, w, c)] = new_state;
+          output_activ_data[Offset(output_activ_shape, b, h, w, c)] =
+              output_gate * std::tanh(new_state);
+        }
+      }
+    }
+  }
+}
+
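In equation form, the gate math in the loop above is the standard LSTM cell update: with a[k] denoting the four contiguous slices of activ_temp (each of width output_depth),

    i = sigmoid(a[0]),  g = tanh(a[1]),  f = sigmoid(a[2]),  o = sigmoid(a[3])
    state_new = i * g + f * state_prev
    output    = o * tanh(state_new)

i.e. input, cell-input, forget and output gates in that order, which is why intern_activ_depth is required to be exactly 4 * output_depth.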
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
 inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
                      const float* prev_activ_data,
                      const Dims<4>& prev_activ_dims, const float* weights_data,
@@ -2214,77 +2541,17 @@
                      const Dims<4>& output_activ_dims, float* concat_temp_data,
                      const Dims<4>& concat_temp_dims, float* activ_temp_data,
                      const Dims<4>& activ_temp_dims) {
-  const int batches =
-      MatchingArraySize(input_dims, 3, prev_activ_dims, 3, prev_state_dims, 3,
-                        output_state_dims, 3, output_activ_dims, 3);
-  const int height =
-      MatchingArraySize(input_dims, 2, prev_activ_dims, 2, prev_state_dims, 2,
-                        output_state_dims, 2, output_activ_dims, 2);
-  const int width =
-      MatchingArraySize(input_dims, 1, prev_activ_dims, 1, prev_state_dims, 1,
-                        output_state_dims, 1, output_activ_dims, 1);
-  TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1);
-  TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1);
-  const int input_depth = ArraySize(input_dims, 0);
-  const int prev_activ_depth = ArraySize(prev_activ_dims, 0);
-  const int total_input_depth = prev_activ_depth + input_depth;
-  TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth);
-  TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3),
-                  1);
-  const int intern_activ_depth =
-      MatchingArraySize(weights_dims, 1, bias_dims, 0);
-  TFLITE_CHECK_EQ(intern_activ_depth % 4, 0);
-  const int output_depth =
-      MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0,
-                        output_state_dims, 0, output_activ_dims, 0);
-  TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4);
+  tflite::LstmCellParams op_params;
+  // The float LSTM cell reads no parameters; leave op_params at its defaults.
 
-  // Concatenate prev_activ and input data together
-  std::vector<float const*> concat_input_arrays_data;
-  std::vector<Dims<4> const*> concat_input_arrays_dims;
-  concat_input_arrays_data.push_back(input_data);
-  concat_input_arrays_data.push_back(prev_activ_data);
-  concat_input_arrays_dims.push_back(&input_dims);
-  concat_input_arrays_dims.push_back(&prev_activ_dims);
-  Concatenation<FusedActivationFunctionType::kNone, float>(
-      0, &(concat_input_arrays_data[0]), &(concat_input_arrays_dims[0]),
-      concat_input_arrays_data.size(), concat_temp_data, concat_temp_dims);
-
-  // Fully connected
-  FullyConnected<FusedActivationFunctionType::kNone>(
-      concat_temp_data, concat_temp_dims, weights_data, weights_dims, bias_data,
-      bias_dims, activ_temp_data, activ_temp_dims);
-
-  // Memory state update (the LSTM "guts")
-  for (int b = 0; b < batches; ++b) {
-    for (int w = 0; w < width; ++w) {
-      for (int h = 0; h < height; ++h) {
-        for (int c = 0; c < output_depth; ++c) {
-          const float input_gate =
-              1.f /
-              (1.f + std::exp(-activ_temp_data[Offset(
-                         activ_temp_dims, 0 * output_depth + c, w, h, b)]));
-          const float new_input = std::tanh(activ_temp_data[Offset(
-              activ_temp_dims, 1 * output_depth + c, w, h, b)]);
-          const float forget_gate =
-              1.f /
-              (1.f + std::exp(-activ_temp_data[Offset(
-                         activ_temp_dims, 2 * output_depth + c, w, h, b)]));
-          const float output_gate =
-              1.f /
-              (1.f + std::exp(-activ_temp_data[Offset(
-                         activ_temp_dims, 3 * output_depth + c, w, h, b)]));
-          const float new_state =
-              input_gate * new_input +
-              forget_gate *
-                  prev_state_data[Offset(prev_state_dims, c, w, h, b)];
-          output_state_data[Offset(output_state_dims, c, w, h, b)] = new_state;
-          output_activ_data[Offset(output_activ_dims, c, w, h, b)] =
-              output_gate * std::tanh(new_state);
-        }
-      }
-    }
-  }
+  LstmCell(op_params, DimsToShape(input_dims), input_data,
+           DimsToShape(prev_activ_dims), prev_activ_data,
+           DimsToShape(weights_dims), weights_data, DimsToShape(bias_dims),
+           bias_data, DimsToShape(prev_state_dims), prev_state_data,
+           DimsToShape(output_state_dims), output_state_data,
+           DimsToShape(output_activ_dims), output_activ_data,
+           DimsToShape(concat_temp_dims), concat_temp_data,
+           DimsToShape(activ_temp_dims), activ_temp_data);
 }
 
 // Quantized LSTM cell implementation.
@@ -2372,52 +2639,90 @@
 // aiming for 16-bit fixed-point quantization of these internal nodes here.
 //
 template <int StateIntegerBits>
-void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
-              const uint8* prev_activ_data_uint8,
-              const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
-              const Dims<4>& weights_dims, const int32* bias_data_int32,
-              const Dims<4>& bias_dims, const int16* prev_state_data_int16,
-              const Dims<4>& prev_state_dims, int16* output_state_data_int16,
-              const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
-              const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
-              const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
-              const Dims<4>& activ_temp_dims, int32 weights_zero_point,
-              int32 accum_multiplier, int accum_shift,
-              gemmlowp::GemmContext* gemm_context) {
+inline void LstmCell(
+    const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
+    const uint8* input_data_uint8,
+    const RuntimeShape& unextended_prev_activ_shape,
+    const uint8* prev_activ_data_uint8, const RuntimeShape& weights_shape,
+    const uint8* weights_data_uint8, const RuntimeShape& unextended_bias_shape,
+    const int32* bias_data_int32,
+    const RuntimeShape& unextended_prev_state_shape,
+    const int16* prev_state_data_int16,
+    const RuntimeShape& unextended_output_state_shape,
+    int16* output_state_data_int16,
+    const RuntimeShape& unextended_output_activ_shape,
+    uint8* output_activ_data_uint8,
+    const RuntimeShape& unextended_concat_temp_shape,
+    uint8* concat_temp_data_uint8,
+    const RuntimeShape& unextended_activ_temp_shape,
+    int16* activ_temp_data_int16, gemmlowp::GemmContext* gemm_context) {
   (void)gemm_context;  // only used in optimized code.
+  int32 weights_zero_point = params.weights_zero_point;
+  int32 accum_multiplier = params.accum_multiplier;
+  int accum_shift = params.accum_shift;
+  TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
+  const RuntimeShape input_shape =
+      RuntimeShape::ExtendedShape(4, unextended_input_shape);
+  const RuntimeShape prev_activ_shape =
+      RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
+  const RuntimeShape bias_shape =
+      RuntimeShape::ExtendedShape(4, unextended_bias_shape);
+  const RuntimeShape prev_state_shape =
+      RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
+  const RuntimeShape output_state_shape =
+      RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
+  const RuntimeShape output_activ_shape =
+      RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
+  const RuntimeShape concat_temp_shape =
+      RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
+  const RuntimeShape activ_temp_shape =
+      RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
+  TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
 
   // Gather dimensions information, and perform consistency checks.
-  const int outer_size =
-      MatchingFlatSizeSkipDim(input_dims, 0, prev_activ_dims, prev_state_dims,
-                              output_state_dims, output_activ_dims);
-  TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1);
-  TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1);
-  const int input_depth = ArraySize(input_dims, 0);
-  const int prev_activ_depth = ArraySize(prev_activ_dims, 0);
+  const int weights_dim_count = weights_shape.DimensionsCount();
+  const int outer_size = MatchingFlatSizeSkipDim(
+      input_shape, 3, prev_activ_shape, prev_state_shape, output_state_shape,
+      output_activ_shape);
+  const int input_depth = input_shape.Dims(3);
+  const int prev_activ_depth = prev_activ_shape.Dims(3);
   const int total_input_depth = prev_activ_depth + input_depth;
-  TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth);
-  TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3),
-                  1);
+  TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
+                   total_input_depth);
   const int intern_activ_depth =
-      MatchingArraySize(weights_dims, 1, bias_dims, 0);
-  TFLITE_CHECK_EQ(intern_activ_depth % 4, 0);
+      MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
+  TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
+                   intern_activ_depth * total_input_depth);
+  TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
+  TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
   const int output_depth =
-      MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0,
-                        output_state_dims, 0, output_activ_dims, 0);
-  TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4);
-  const int fc_batches = FlatSizeSkipDim(activ_temp_dims, 0);
+      MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
+                  3, output_activ_shape, 3);
+  TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
+  const int fc_batches = FlatSizeSkipDim(activ_temp_shape, 3);
   const int fc_output_depth =
-      MatchingArraySize(weights_dims, 1, activ_temp_dims, 0);
-  const int fc_accum_depth = ArraySize(weights_dims, 0);
-  TFLITE_CHECK_EQ(fc_output_depth, 4 * output_depth);
+      MatchingDim(weights_shape, weights_dim_count - 2, activ_temp_shape, 3);
+  const int fc_accum_depth = total_input_depth;
+  TFLITE_DCHECK_EQ(fc_output_depth, 4 * output_depth);
 
   // Depth-concatenate prev_activ and input data together.
   uint8 const* concat_input_arrays_data[2] = {input_data_uint8,
                                               prev_activ_data_uint8};
-  Dims<4> const* concat_input_arrays_dims[2] = {&input_dims, &prev_activ_dims};
-  Concatenation<FusedActivationFunctionType::kNone, uint8>(
-      0, concat_input_arrays_data, concat_input_arrays_dims, 2,
-      concat_temp_data_uint8, concat_temp_dims);
+  const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape,
+                                                       &prev_activ_shape};
+  tflite::ConcatenationParams concat_params;
+  concat_params.axis = 3;
+  concat_params.inputs_count = 2;
+  Concatenation(concat_params, concat_input_arrays_shapes,
+                concat_input_arrays_data, concat_temp_shape,
+                concat_temp_data_uint8);
 
   // Implementation of the fully connected node inside the LSTM cell.
   // The operands are 8-bit integers, the accumulators are internally 32-bit
@@ -2523,6 +2828,37 @@
   }
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <int StateIntegerBits>
+void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
+              const uint8* prev_activ_data_uint8,
+              const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
+              const Dims<4>& weights_dims, const int32* bias_data_int32,
+              const Dims<4>& bias_dims, const int16* prev_state_data_int16,
+              const Dims<4>& prev_state_dims, int16* output_state_data_int16,
+              const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
+              const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
+              const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
+              const Dims<4>& activ_temp_dims, int32 weights_zero_point,
+              int32 accum_multiplier, int accum_shift,
+              gemmlowp::GemmContext* gemm_context) {
+  tflite::LstmCellParams op_params;
+  op_params.weights_zero_point = weights_zero_point;
+  op_params.accum_multiplier = accum_multiplier;
+  op_params.accum_shift = accum_shift;
+
+  LstmCell<StateIntegerBits>(
+      op_params, DimsToShape(input_dims), input_data_uint8,
+      DimsToShape(prev_activ_dims), prev_activ_data_uint8,
+      DimsToShape(weights_dims), weights_data_uint8, DimsToShape(bias_dims),
+      bias_data_int32, DimsToShape(prev_state_dims), prev_state_data_int16,
+      DimsToShape(output_state_dims), output_state_data_int16,
+      DimsToShape(output_activ_dims), output_activ_data_uint8,
+      DimsToShape(concat_temp_dims), concat_temp_data_uint8,
+      DimsToShape(activ_temp_dims), activ_temp_data_int16, gemm_context);
+}
+
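The wrapper just added follows the legacy-wrapper pattern used throughout this refactor: the old flat-argument signature packs an op-params struct and forwards to the new params-first, RuntimeShape-based overload. A toy standalone sketch of the same pattern (names are illustrative only, not TFLite APIs):

    // New-style entry point: params struct first, then sizes/data.
    struct ScaleParams {
      float scale;
    };

    inline void Scale(const ScaleParams& params, int size, const float* input,
                      float* output) {
      for (int i = 0; i < size; ++i) output[i] = params.scale * input[i];
    }

    // Legacy entry point: packs the params struct and forwards, so call sites
    // can migrate incrementally.
    inline void Scale(const float* input, int size, float scale, float* output) {
      ScaleParams op_params;
      op_params.scale = scale;
      Scale(op_params, size, input, output);
    }
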
 template <typename Scalar>
 void Split(const SplitParams& params, const RuntimeShape& input_shape,
            const Scalar* input_data, const RuntimeShape* const* output_shapes,
@@ -2902,9 +3238,9 @@
   }
 }
 
-inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
-                    float beta, float* output_data,
-                    const RuntimeShape& output_shape) {
+inline void Softmax(const SoftmaxParams& params,
+                    const RuntimeShape& input_shape, const float* input_data,
+                    const RuntimeShape& output_shape, float* output_data) {
   const int trailing_dim = input_shape.DimensionsCount() - 1;
   const int outer_size =
       MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
@@ -2923,21 +3259,33 @@
     // Compute sum.
     float sum = 0.f;
     for (int c = 0; c < depth; ++c) {
-      sum += std::exp((input_data[i * depth + c] - max) * beta);
+      sum += std::exp((input_data[i * depth + c] - max) * params.beta);
     }
 
     // Compute result.
     for (int c = 0; c < depth; ++c) {
       output_data[i * depth + c] =
-          std::exp((input_data[i * depth + c] - max) * beta) / sum;
+          std::exp((input_data[i * depth + c] - max) * params.beta) / sum;
     }
   }
 }
 
-inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
-                    int32 input_beta_multiplier, int32 input_beta_left_shift,
-                    int diff_min, uint8* output_data,
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
+                    float beta, float* output_data,
                     const RuntimeShape& output_shape) {
+  SoftmaxParams params;
+  params.beta = beta;
+  Softmax(params, input_shape, input_data, output_shape, output_data);
+}
+
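As a reference for the SoftmaxParams change, here is a standalone sketch of the per-row computation the float Softmax above performs (subtract the row max for numerical stability, scale by beta, exponentiate, normalize); the names below are illustrative and independent of the TFLite headers:

    #include <cmath>
    #include <cstdio>

    // One row of softmax with a beta (temperature) factor.
    void SoftmaxRow(const float* input, int depth, float beta, float* output) {
      float max_val = input[0];
      for (int c = 1; c < depth; ++c) max_val = std::fmax(max_val, input[c]);
      float sum = 0.f;
      for (int c = 0; c < depth; ++c) sum += std::exp((input[c] - max_val) * beta);
      for (int c = 0; c < depth; ++c) {
        output[c] = std::exp((input[c] - max_val) * beta) / sum;
      }
    }

    int main() {
      const float in[3] = {1.f, 2.f, 3.f};
      float out[3];
      SoftmaxRow(in, 3, /*beta=*/1.f, out);
      printf("%f %f %f\n", out[0], out[1], out[2]);  // ~0.09 0.24 0.67
      return 0;
    }
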
+inline void Softmax(const SoftmaxParams& params,
+                    const RuntimeShape& input_shape, const uint8* input_data,
+                    const RuntimeShape& output_shape, uint8* output_data) {
+  const int32 input_beta_multiplier = params.input_multiplier;
+  const int32 input_beta_left_shift = params.input_left_shift;
+  const int diff_min = params.diff_min;
   // The representation chosen for the input to the exp() function is Q5.26.
   // We need to leave extra space since values that we skip might be as large as
   // -32 before multiplying by input_beta_multiplier, and therefore as large as
@@ -3015,8 +3363,22 @@
   }
 }
 
-inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
-                       float* output_data, const RuntimeShape& output_shape) {
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy
+inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
+                    int32 input_beta_multiplier, int32 input_beta_left_shift,
+                    int diff_min, uint8* output_data,
+                    const RuntimeShape& output_shape) {
+  SoftmaxParams params;
+  params.input_multiplier = input_beta_multiplier;
+  params.input_left_shift = input_beta_left_shift;
+  params.diff_min = diff_min;
+  Softmax(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void LogSoftmax(const SoftmaxParams& params,
+                       const RuntimeShape& input_shape, const float* input_data,
+                       const RuntimeShape& output_shape, float* output_data) {
   const int trailing_dim = input_shape.DimensionsCount() - 1;
   const int outer_size =
       MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
@@ -3046,6 +3408,15 @@
   }
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy
+inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
+                       float* output_data, const RuntimeShape& output_shape) {
+  SoftmaxParams params;
+  // No params currently used for float LogSoftmax.
+  LogSoftmax(params, input_shape, input_data, output_shape, output_data);
+}
+
 // Although currently the name of this function says that it cannot handle
 // values less than 1, in practice it can handle as low as 1/x_max, where
 // x_max is the largest representable input.  In other words, the output range
@@ -3161,16 +3532,19 @@
       input_val);
 }
 
-inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
-                       int32 input_multiplier, int32 input_left_shift,
-                       int32 reverse_scaling_divisor,
-                       int32 reverse_scaling_right_shift, int diff_min,
-                       uint8* output_data, const RuntimeShape& output_shape) {
+inline void LogSoftmax(const SoftmaxParams& params,
+                       const RuntimeShape& input_shape, const uint8* input_data,
+                       const RuntimeShape& output_shape, uint8* output_data) {
+  const int32 input_multiplier = params.input_multiplier;
+  const int32 input_left_shift = params.input_left_shift;
+  const int32 reverse_scaling_divisor = params.reverse_scaling_divisor;
+  const int32 reverse_scaling_right_shift = params.reverse_scaling_right_shift;
+  const int diff_min = params.diff_min;
   // The representation chosen for the input to the exp() function is Q5.26.
-  // We need to leave extra space since values that we skip might be as large as
-  // -32 before multiplying by input_beta_multiplier, and therefore as large as
-  // -16 afterwards.  Note that exp(-8) is definitely not insignificant to
-  // accumulation, but exp(-16) definitely is.
+  // We need to leave extra space since values that we skip might be as large
+  // as -32 before multiplying by input_beta_multiplier, and therefore as
+  // large as -16 afterwards.  Note that exp(-8) is definitely not
+  // insignificant to accumulation, but exp(-16) definitely is.
   static constexpr int kScaledDiffIntegerBits = 5;
   static constexpr int kAccumulationIntegerBits = 12;
   static constexpr int kOutputIntegerBits = 4;
@@ -3222,7 +3596,7 @@
         std::max(diff_min - 1,  // Note use of > below instead of >= above.
                  MultiplyByQuantizedMultiplierSmallerThanOneExp(
                      rescaled_diff_min, reverse_scaling_divisor,
-                     kReverseShift * reverse_scaling_right_shift));
+                     -reverse_scaling_right_shift));
 
     for (int c = 0; c < depth; ++c) {
       int32 input_diff =
@@ -3247,6 +3621,22 @@
   }
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
+                       int32 input_multiplier, int32 input_left_shift,
+                       int32 reverse_scaling_divisor,
+                       int32 reverse_scaling_right_shift, int diff_min,
+                       uint8* output_data, const RuntimeShape& output_shape) {
+  SoftmaxParams params;
+  params.input_multiplier = input_multiplier;
+  params.input_left_shift = input_left_shift;
+  params.reverse_scaling_divisor = reverse_scaling_divisor;
+  params.reverse_scaling_right_shift = reverse_scaling_right_shift;
+  params.diff_min = diff_min;
+  LogSoftmax(params, input_shape, input_data, output_shape, output_data);
+}
+
 inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
                      const RuntimeShape& output_shape, float* output_data) {
   const int flat_size = MatchingFlatSize(input_shape, output_shape);
@@ -3258,10 +3648,22 @@
   }
 }
 
-inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
-                     int32 input_zero_point, int32 input_range_radius,
-                     int32 input_multiplier, int input_left_shift,
-                     uint8* output_data, const RuntimeShape& output_shape) {
+// Convenience version that allows, for example, generated-code calls to be
+// uniform between data types.
+inline void Logistic(const LogisticParams&, const RuntimeShape& input_shape,
+                     const float* input_data, const RuntimeShape& output_shape,
+                     float* output_data) {
+  // Drop params: not needed.
+  Logistic(input_shape, input_data, output_shape, output_data);
+}
+
+inline void Logistic(const LogisticParams& params,
+                     const RuntimeShape& input_shape, const uint8* input_data,
+                     const RuntimeShape& output_shape, uint8* output_data) {
+  const int32 input_zero_point = params.input_zero_point;
+  const int32 input_range_radius = params.input_range_radius;
+  const int32 input_multiplier = params.input_multiplier;
+  const int input_left_shift = params.input_left_shift;
   const int flat_size = MatchingFlatSize(input_shape, output_shape);
 
   for (int i = 0; i < flat_size; i++) {
@@ -3296,7 +3698,22 @@
   }
 }
 
-inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
+                     int32 input_zero_point, int32 input_range_radius,
+                     int32 input_multiplier, int input_left_shift,
+                     uint8* output_data, const RuntimeShape& output_shape) {
+  LogisticParams params;
+  params.input_zero_point = input_zero_point;
+  params.input_range_radius = input_range_radius;
+  params.input_multiplier = input_multiplier;
+  params.input_left_shift = input_left_shift;
+  Logistic(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Logistic(const LogisticParams& params,
+                     const RuntimeShape& input_shape, const int16* input_data,
                      const RuntimeShape& output_shape, int16* output_data) {
   const int flat_size = MatchingFlatSize(input_shape, output_shape);
 
@@ -3314,6 +3731,15 @@
   }
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
+                     const RuntimeShape& output_shape, int16* output_data) {
+  LogisticParams params;
+  // No params currently needed by int16 Logistic.
+  Logistic(params, input_shape, input_data, output_shape, output_data);
+}
+
 inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
                  const RuntimeShape& output_shape, float* output_data) {
   const int flat_size = MatchingFlatSize(input_shape, output_shape);
@@ -3325,10 +3751,22 @@
   }
 }
 
-inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
-                 int32 input_zero_point, int32 input_range_radius,
-                 int32 input_multiplier, int input_left_shift,
-                 uint8* output_data, const RuntimeShape& output_shape) {
+// Convenience version that allows, for example, generated-code calls to be
+// uniform between data types.
+inline void Tanh(const TanhParams&, const RuntimeShape& input_shape,
+                 const float* input_data, const RuntimeShape& output_shape,
+                 float* output_data) {
+  // Drop params: not needed.
+  Tanh(input_shape, input_data, output_shape, output_data);
+}
+
+inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
+                 const uint8* input_data, const RuntimeShape& output_shape,
+                 uint8* output_data) {
+  const int32 input_zero_point = params.input_zero_point;
+  const int32 input_range_radius = params.input_range_radius;
+  const int32 input_multiplier = params.input_multiplier;
+  const int input_left_shift = params.input_left_shift;
   const int32 output_zero_point = 128;
   const int flat_size = MatchingFlatSize(input_shape, output_shape);
 
@@ -3365,9 +3803,24 @@
   }
 }
 
-inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
-                 int input_left_shift, int16* output_data,
-                 const RuntimeShape& output_shape) {
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
+                 int32 input_zero_point, int32 input_range_radius,
+                 int32 input_multiplier, int input_left_shift,
+                 uint8* output_data, const RuntimeShape& output_shape) {
+  TanhParams params;
+  params.input_zero_point = input_zero_point;
+  params.input_range_radius = input_range_radius;
+  params.input_multiplier = input_multiplier;
+  params.input_left_shift = input_left_shift;
+  Tanh(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
+                 const int16* input_data, const RuntimeShape& output_shape,
+                 int16* output_data) {
+  const int input_left_shift = params.input_left_shift;
   // Support for shifts is limited until we have a parameterized version of
   // SaturatingRoundingMultiplyByPOT().
   TFLITE_DCHECK_GE(input_left_shift, 0);
@@ -3398,6 +3851,16 @@
   }
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
+                 int input_left_shift, int16* output_data,
+                 const RuntimeShape& output_shape) {
+  TanhParams params;
+  params.input_left_shift = input_left_shift;
+  Tanh(params, input_shape, input_data, output_shape, output_data);
+}
+
 inline void Dequantize(const tflite::DequantizationParams& op_params,
                        const RuntimeShape& input_shape, const uint8* input_data,
                        const RuntimeShape& output_shape, float* output_data) {
@@ -4270,6 +4733,16 @@
   }
 }
 
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T>
+inline void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
+                    const RuntimeShape&, const T* input2_data,
+                    const RuntimeShape& output_shape, T* output_data) {
+  // Drop shape of second input: not needed.
+  Minimum(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
 template <typename T>
 void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
              const T* input2_data, const RuntimeShape& output_shape,
@@ -4282,6 +4755,16 @@
   }
 }
 
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T>
+inline void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
+                    const RuntimeShape&, const T* input2_data,
+                    const RuntimeShape& output_shape, T* output_data) {
+  // Drop shape of second input: not needed.
+  Maximum(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
 template <typename T, typename Op>
 void MaximumMinimumBroadcast4DSlow(const RuntimeShape& unextended_input1_shape,
                                    const T* input1_data,
@@ -4357,50 +4840,105 @@
             std::greater<T1>());
 }
 
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T1, typename T2, typename T3>
+inline void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
+                   const RuntimeShape& input2_shape, const T3* input2_data,
+                   const RuntimeShape& output_shape, T2* output_data) {
+  // Drop shape of second input: not needed.
+  ArgMax(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
 template <typename T>
-void Transpose(const T* input, const Dims<4>& input_dims, T* output,
-               const Dims<4>& output_dims, const int* permuted_axes) {
+void Transpose(const TransposeParams& params,
+               const RuntimeShape& unextended_input_shape, const T* input_data,
+               const RuntimeShape& unextended_output_shape, T* output_data) {
+  const int unextended_output_size = unextended_output_shape.DimensionsCount();
+  TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_output_size, 4);
+  TFLITE_DCHECK_EQ(unextended_output_size, params.perm_count);
+  const RuntimeShape input_shape =
+      RuntimeShape::ExtendedShape(4, unextended_input_shape);
+  const RuntimeShape output_shape =
+      RuntimeShape::ExtendedShape(4, unextended_output_shape);
+  const int input_ext_size = 4 - unextended_input_shape.DimensionsCount();
+  const int output_ext_size = 4 - unextended_output_size;
+
+  // The perm data is extended to match the output, with each index
+  // incremented by the amount of front padding of the input shape.
+  int extended_perm[4];
+  for (int i = 0; i < output_ext_size; ++i) {
+    extended_perm[i] = i;
+  }
+  for (int i = 0; i < unextended_output_size; ++i) {
+    extended_perm[i + output_ext_size] = params.perm[i] + input_ext_size;
+  }
+
   int out_sizes[4];
   // Compute the inverse permutation array so we can do an output-centered
   // transpose. Also check that the output shape matches the input shape.
   for (int k = 0; k < 4; k++) {
-    out_sizes[k] =
-        MatchingArraySize(input_dims, permuted_axes[k], output_dims, k);
+    out_sizes[k] = MatchingDim(input_shape, extended_perm[k], output_shape, k);
   }
 
   // Naive transpose loop (iterate on output index and compute input index).
   int o[4];  // loop index (on output).
   int i[4];
   for (o[3] = 0; o[3] < out_sizes[3]; o[3]++) {
-    i[permuted_axes[3]] = o[3];
+    i[extended_perm[3]] = o[3];
     for (o[2] = 0; o[2] < out_sizes[2]; o[2]++) {
-      i[permuted_axes[2]] = o[2];
+      i[extended_perm[2]] = o[2];
       for (o[1] = 0; o[1] < out_sizes[1]; o[1]++) {
-        i[permuted_axes[1]] = o[1];
+        i[extended_perm[1]] = o[1];
         for (o[0] = 0; o[0] < out_sizes[0]; o[0]++) {
-          i[permuted_axes[0]] = o[0];
-          output[Offset(output_dims, o)] = input[Offset(input_dims, i)];
+          i[extended_perm[0]] = o[0];
+          output_data[Offset(output_shape, o)] =
+              input_data[Offset(input_shape, i)];
         }
       }
     }
   }
 }
 
-inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
-                          const float* filter_data, const Dims<4>& filter_dims,
-                          int stride_width, int stride_height, int pad_width,
-                          int pad_height, float* output_data,
-                          const Dims<4>& output_dims, float* /*im2col_data*/,
-                          const Dims<4>& /*im2col_dims*/) {
-  const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
-  const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0);
-  const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0);
-  const int input_height = ArraySize(input_dims, 2);
-  const int input_width = ArraySize(input_dims, 1);
-  const int filter_height = ArraySize(filter_dims, 2);
-  const int filter_width = ArraySize(filter_dims, 1);
-  const int output_height = ArraySize(output_dims, 2);
-  const int output_width = ArraySize(output_dims, 1);
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <typename T>
+void Transpose(const T* input, const Dims<4>& input_dims, T* output,
+               const Dims<4>& output_dims, const int* permuted_axes) {
+  TransposeParams params;
+  params.perm_count = 4;
+  for (int i = 0; i < 4; ++i) {
+    params.perm[i] = 3 - permuted_axes[3 - i];
+  }
+  Transpose(params, DimsToShape(input_dims), input, DimsToShape(output_dims),
+            output);
+}
+
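Two index conventions meet here: the new Transpose front-pads lower-rank permutations to the 4-D ExtendedShape layout, while the legacy wrapper above converts a Dims<4>-ordered permutation with `3 - permuted_axes[3 - i]`. A small standalone check of the padding step, assuming a rank-2 transpose with perm = {1, 0}:

    #include <cstdio>

    int main() {
      const int perm_count = 2;
      const int perm[2] = {1, 0};
      const int input_ext_size = 4 - 2;   // front padding added to the input
      const int output_ext_size = 4 - 2;  // front padding added to the output

      int extended_perm[4];
      for (int i = 0; i < output_ext_size; ++i) extended_perm[i] = i;
      for (int i = 0; i < perm_count; ++i) {
        extended_perm[i + output_ext_size] = perm[i] + input_ext_size;
      }
      // Prints 0 1 3 2: the padded leading dims stay in place, the real dims swap.
      for (int i = 0; i < 4; ++i) printf("%d ", extended_perm[i]);
      printf("\n");
      return 0;
    }
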
+inline void TransposeConv(
+    const ConvParams& params, const RuntimeShape& input_shape,
+    const float* input_data, const RuntimeShape& filter_shape,
+    const float* filter_data, const RuntimeShape& output_shape,
+    float* output_data, const RuntimeShape& im2col_shape, float* im2col_data) {
+  const int stride_width = params.stride_width;
+  const int stride_height = params.stride_height;
+  const int pad_width = params.padding_values.width;
+  const int pad_height = params.padding_values.height;
+  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+  (void)im2col_data;   // only used in optimized code.
+  (void)im2col_shape;  // only used in optimized code.
+
+  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+  const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+  const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
+  const int input_height = input_shape.Dims(1);
+  const int input_width = input_shape.Dims(2);
+  const int filter_height = filter_shape.Dims(1);
+  const int filter_width = filter_shape.Dims(2);
+  const int output_height = output_shape.Dims(1);
+  const int output_width = output_shape.Dims(2);
 
   // Although transpose convolution simplifies to convolution with transposed
   // weights for strides of 1, non-unitary striding complicates matters. To
@@ -4409,7 +4947,7 @@
   // computing their influence on the output, rather than looping through the
   // output elements in the typical "gather" access pattern of a conv. We
   // therefore must initialize the output array to zero.
-  const int num_elements = FlatSize(output_dims);
+  const int num_elements = output_shape.FlatSize();
   for (int i = 0; i < num_elements; i++) {
     output_data[i] = 0.0f;
   }
@@ -4432,13 +4970,14 @@
                 // We cannot accumulate out of bounds
                 if ((out_x >= 0) && (out_x < output_width) && (out_y >= 0) &&
                     (out_y < output_height)) {
-                  float input_value = input_data[Offset(input_dims, in_channel,
-                                                        in_x, in_y, batch)];
+                  float input_value = input_data[Offset(
+                      input_shape, batch, in_y, in_x, in_channel)];
                   float filter_value =
-                      filter_data[Offset(filter_dims, in_channel, filter_x,
-                                         filter_y, out_channel)];
-                  output_data[Offset(output_dims, out_channel, out_x, out_y,
-                                     batch)] += input_value * filter_value;
+                      filter_data[Offset(filter_shape, out_channel, filter_y,
+                                         filter_x, in_channel)];
+                  output_data[Offset(output_shape, batch, out_y, out_x,
+                                     out_channel)] +=
+                      input_value * filter_value;
                 }
               }
             }
@@ -4449,6 +4988,27 @@
   }
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
+                          const float* filter_data, const Dims<4>& filter_dims,
+                          int stride_width, int stride_height, int pad_width,
+                          int pad_height, float* output_data,
+                          const Dims<4>& output_dims, float* im2col_data,
+                          const Dims<4>& im2col_dims) {
+  tflite::ConvParams op_params;
+  // Padding type is ignored, but still set.
+  op_params.padding_type = PaddingType::kSame;
+  op_params.padding_values.width = pad_width;
+  op_params.padding_values.height = pad_height;
+  op_params.stride_width = stride_width;
+  op_params.stride_height = stride_height;
+
+  TransposeConv(op_params, DimsToShape(input_dims), input_data,
+                DimsToShape(filter_dims), filter_data, DimsToShape(output_dims),
+                output_data, DimsToShape(im2col_dims), im2col_data);
+}
+
 template <typename T>
 inline bool EqualFn(T lhs, T rhs) {
   return lhs == rhs;
@@ -4559,9 +5119,11 @@
   op_params.left_shift = left_shift;
   op_params.input1_offset = input1_offset;
   op_params.input1_multiplier = input1_multiplier;
+  // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
   op_params.input1_shift = kReverseShift * input1_shift;
   op_params.input2_offset = input2_offset;
   op_params.input2_multiplier = input2_multiplier;
+  // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
   op_params.input2_shift = kReverseShift * input2_shift;
 
   ComparisonWithScaling<T, F>(op_params, DimsToShape(input1_dims), input1_data,
@@ -4693,9 +5255,11 @@
   op_params.left_shift = left_shift;
   op_params.input1_offset = input1_offset;
   op_params.input1_multiplier = input1_multiplier;
+  // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
   op_params.input1_shift = kReverseShift * input1_shift;
   op_params.input2_offset = input2_offset;
   op_params.input2_multiplier = input2_multiplier;
+  // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
   op_params.input2_shift = kReverseShift * input2_shift;
 
   BroadcastComparison4DSlowWithScaling<T, F>(
@@ -4799,47 +5363,81 @@
 #undef TFLITE_COMPARISON_OP
 
 template <typename D, typename T>
-inline void Select(const D* input_condition_data,
-                   const Dims<4>& input_condition_dims, const T* input_x_data,
-                   const Dims<4>& input_x_dims, const T* input_y_data,
-                   const Dims<4>& input_y_dims, T* output_data,
-                   const Dims<4>& output_dims) {
-  const int64_t flatsize =
-      MatchingFlatSize(input_x_dims, input_y_dims, output_dims);
+void Select(const RuntimeShape& input_condition_shape,
+            const D* input_condition_data, const RuntimeShape& input_x_shape,
+            const T* input_x_data, const RuntimeShape& input_y_shape,
+            const T* input_y_data, const RuntimeShape& output_shape,
+            T* output_data) {
+  const int64_t flatsize = MatchingFlatSize(
+      input_condition_shape, input_x_shape, input_y_shape, output_shape);
   for (int64_t i = 0; i < flatsize; ++i) {
     output_data[i] =
         input_condition_data[i] ? input_x_data[i] : input_y_data[i];
   }
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
 template <typename D, typename T>
-inline void RankOneSelect(const D* input_condition_data,
-                          const Dims<4>& input_condition_dims,
-                          const T* input_x_data, const Dims<4>& input_x_dims,
-                          const T* input_y_data, const Dims<4>& input_y_dims,
-                          T* output_data, const Dims<4>& output_dims) {
-  const int64_t rank = MatchingArraySize(input_condition_dims, 0, input_x_dims,
-                                         3, input_y_dims, 3, output_dims, 3);
+inline void Select(const D* input_condition_data,
+                   const Dims<4>& input_condition_dims, const T* input_x_data,
+                   const Dims<4>& input_x_dims, const T* input_y_data,
+                   const Dims<4>& input_y_dims, T* output_data,
+                   const Dims<4>& output_dims) {
+  Select(DimsToShape(input_condition_dims), input_condition_data,
+         DimsToShape(input_x_dims), input_x_data, DimsToShape(input_y_dims),
+         input_y_data, DimsToShape(output_dims), output_data);
+}
+
+template <typename D, typename T>
+void RankOneSelect(const RuntimeShape& input_condition_shape,
+                   const D* input_condition_data,
+                   const RuntimeShape& input_x_shape, const T* input_x_data,
+                   const RuntimeShape& input_y_shape, const T* input_y_data,
+                   const RuntimeShape& output_shape, T* output_data) {
+  const int64_t outer_size = input_condition_shape.FlatSize();
+  TFLITE_DCHECK_EQ(
+      MatchingDim(input_x_shape, 0, input_y_shape, 0, output_shape, 0),
+      outer_size);
   const int64_t inner_size =
-      MatchingFlatSizeSkipDim(input_x_dims, 3, input_y_dims, output_dims);
+      MatchingFlatSizeSkipDim(input_x_shape, 0, input_y_shape, output_shape);
 
   int64_t offset = 0;
-  for (int64_t i = 0; i < rank; i++) {
+  for (int64_t i = 0; i < outer_size; i++) {
     const T* input_data = input_condition_data[i] ? input_x_data : input_y_data;
     memcpy(output_data + offset, input_data + offset, inner_size * sizeof(T));
     offset += inner_size;
   }
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <typename D, typename T>
+inline void RankOneSelect(const D* input_condition_data,
+                          const Dims<4>& input_condition_dims,
+                          const T* input_x_data, const Dims<4>& input_x_dims,
+                          const T* input_y_data, const Dims<4>& input_y_dims,
+                          T* output_data, const Dims<4>& output_dims) {
+  RankOneSelect(DimsToShape(input_condition_dims), input_condition_data,
+                DimsToShape(input_x_dims), input_x_data,
+                DimsToShape(input_y_dims), input_y_data,
+                DimsToShape(output_dims), output_data);
+}
+
 // For easy implementation, the indices are always given as size-4 vectors.
 template <typename T, typename TI>
 inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
-                          const T* values, T default_value, T* output_data,
-                          const Dims<4>& output_dims, bool value_is_scalar) {
+                          const T* values, T default_value,
+                          bool value_is_scalar,
+                          const RuntimeShape& unextended_output_shape,
+                          T* output_data) {
+  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+  const RuntimeShape output_shape =
+      RuntimeShape::ExtendedShape(4, unextended_output_shape);
   const int value_count = indices.size();
 
   // First fill the output_data with default value.
-  const int num_elements = FlatSize(output_dims);
+  const int num_elements = output_shape.FlatSize();
   for (int i = 0; i < num_elements; ++i) {
     output_data[i] = default_value;
   }
@@ -4851,8 +5449,8 @@
       const std::vector<TI>& index = indices[i];
       TFLITE_DCHECK_EQ(index.size(), 4);
       const T value = *values;  // just use the first value.
-      output_data[Offset(output_dims, index[3], index[2], index[1], index[0])] =
-          value;
+      output_data[Offset(output_shape, index[0], index[1], index[2],
+                         index[3])] = value;
     }
     return;
   }
@@ -4862,11 +5460,21 @@
     const std::vector<TI>& index = indices[i];
     TFLITE_DCHECK_EQ(index.size(), 4);
     const T value = values[i];
-    output_data[Offset(output_dims, index[3], index[2], index[1], index[0])] =
+    output_data[Offset(output_shape, index[0], index[1], index[2], index[3])] =
         value;
   }
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <typename T, typename TI>
+inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
+                          const T* values, T default_value, T* output_data,
+                          const Dims<4>& output_dims, bool value_is_scalar) {
+  SparseToDense(indices, values, default_value, value_is_scalar,
+                DimsToShape(output_dims), output_data);
+}
+
 template <typename T>
 inline void Pow(const RuntimeShape& input1_shape, const T* input1_data,
                 const RuntimeShape& input2_shape, const T* input2_data,
diff --git a/tensorflow/contrib/lite/kernels/internal/test_util.cc b/tensorflow/contrib/lite/kernels/internal/test_util.cc
index 9b1fd9b..5ae4b19 100644
--- a/tensorflow/contrib/lite/kernels/internal/test_util.cc
+++ b/tensorflow/contrib/lite/kernels/internal/test_util.cc
@@ -43,17 +43,21 @@
 
 // This is copied from an internal function in propagate_fixed_sizes.cc
 bool ComputeConvSizes(Dims<4> input_dims, int output_depth, int filter_width,
-                      int filter_height, int stride, PaddingType padding_type,
+                      int filter_height, int stride, int dilation_width_factor,
+                      int dilation_height_factor, PaddingType padding_type,
                       Dims<4>* output_dims, int* pad_width, int* pad_height) {
   const int input_width = ArraySize(input_dims, 1);
   const int input_height = ArraySize(input_dims, 2);
   const int batch = ArraySize(input_dims, 3);
 
+  int dilated_filter_width = dilation_width_factor * (filter_width - 1) + 1;
+  int dilated_filter_height = dilation_height_factor * (filter_height - 1) + 1;
+
   int output_height = 0;
   int output_width = 0;
   if (padding_type == PaddingType::kValid) {
-    output_height = (input_height + stride - filter_height) / stride;
-    output_width = (input_width + stride - filter_width) / stride;
+    output_height = (input_height + stride - dilated_filter_height) / stride;
+    output_width = (input_width + stride - dilated_filter_width) / stride;
   } else if (padding_type == PaddingType::kSame) {
     output_height = (input_height + stride - 1) / stride;
     output_width = (input_width + stride - 1) / stride;
@@ -65,9 +69,13 @@
     return false;
   }
 
-  *pad_height =
-      ((output_height - 1) * stride + filter_height - input_height) / 2;
-  *pad_width = ((output_width - 1) * stride + filter_width - input_width) / 2;
+  *pad_height = std::max(
+      0, ((output_height - 1) * stride + dilated_filter_height - input_height) /
+             2);
+  *pad_width = std::max(
+      0,
+      ((output_width - 1) * stride + dilated_filter_width - input_width) / 2);
+
   *output_dims =
       MakeDimsForInference(output_depth, output_width, output_height, batch);
   return true;
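A quick standalone check of the dilation-aware sizes introduced above: a k-tap filter with dilation factor d spans d*(k-1)+1 input elements, and the VALID output size is (input + stride - dilated_filter) / stride. Values below are only an example:

    #include <cstdio>

    int main() {
      const int input_width = 10, filter_width = 3, stride = 1;
      const int dilation_width_factor = 2;
      const int dilated_filter_width =
          dilation_width_factor * (filter_width - 1) + 1;  // 5
      const int output_width =
          (input_width + stride - dilated_filter_width) / stride;  // 6
      printf("dilated=%d output=%d\n", dilated_filter_width, output_width);
      return 0;
    }
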
diff --git a/tensorflow/contrib/lite/kernels/internal/test_util.h b/tensorflow/contrib/lite/kernels/internal/test_util.h
index 26078ce..cb6d8b1 100644
--- a/tensorflow/contrib/lite/kernels/internal/test_util.h
+++ b/tensorflow/contrib/lite/kernels/internal/test_util.h
@@ -31,7 +31,8 @@
 
 // Computes output and padding dimensions.
 bool ComputeConvSizes(Dims<4> input_dims, int output_depth, int filter_width,
-                      int filter_height, int stride, PaddingType padding_type,
+                      int filter_height, int stride, int dilation_width_factor,
+                      int dilation_height_factor, PaddingType padding_type,
                       Dims<4>* output_dims, int* pad_width, int* pad_height);
 
 // Returns a mt19937 random engine.
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index 023707d..070ad4e 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -18,6 +18,12 @@
 #include <cstring>
 #include <iterator>
 
+// TODO: Remove once AOSP has external/absl setup.
+#if __ANDROID__
+#define ABSL_DEPRECATED(x)
+#else
+#include "absl/base/macros.h"
+#endif  // __ANDROID__
 #include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
 
 namespace tflite {
@@ -179,12 +185,15 @@
       dims_[i] = val;
     }
   }
+
   inline int32* DimsData() {
     return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
   }
   inline const int32* DimsData() const {
     return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
   }
+  // The caller must ensure that the shape is no bigger than 4-D.
+  inline const int32* DimsDataUpTo4D() const { return dims_; }
 
   inline void Resize(int dimensions_count) {
     if (size_ > kMaxSmallSize) {
@@ -283,6 +292,12 @@
   return result;
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) {
+  return RuntimeShape(
+      {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
+}
+
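As the conversion above shows, Dims<4> stores sizes innermost-first ({depth, width, height, batch}) while RuntimeShape stores them outermost-first ({batch, height, width, depth}), so DimsToShape simply reverses the array. A plain-array illustration of that reversal (values are arbitrary):

    #include <cstdio>

    int main() {
      const int dims_sizes[4] = {8, 5, 4, 2};  // depth=8, width=5, height=4, batch=2
      int shape_dims[4];
      for (int i = 0; i < 4; ++i) shape_dims[i] = dims_sizes[3 - i];
      for (int i = 0; i < 4; ++i) printf("%d ", shape_dims[i]);  // prints 2 4 5 8
      printf("\n");
      return 0;
    }
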
 // Gets next index to iterate through a multidimensional array.
 inline bool NextIndex(const int num_dims, const int* dims, int* current) {
   if (num_dims == 0) {
@@ -340,11 +355,12 @@
 }
 
 inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3) {
-  TFLITE_DCHECK(i0 >= 0 && i0 < shape.Dims(0));
-  TFLITE_DCHECK(i1 >= 0 && i1 < shape.Dims(1));
-  TFLITE_DCHECK(i2 >= 0 && i2 < shape.Dims(2));
-  TFLITE_DCHECK(i3 >= 0 && i3 < shape.Dims(3));
-  const int* dims_data = shape.DimsData();
+  TFLITE_DCHECK_EQ(shape.DimensionsCount(), 4);
+  const int* dims_data = shape.DimsDataUpTo4D();
+  TFLITE_DCHECK(i0 >= 0 && i0 < dims_data[0]);
+  TFLITE_DCHECK(i1 >= 0 && i1 < dims_data[1]);
+  TFLITE_DCHECK(i2 >= 0 && i2 < dims_data[2]);
+  TFLITE_DCHECK(i3 >= 0 && i3 < dims_data[3]);
   return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3;
 }
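A quick worked example of the row-major flattening used by Offset above, ((i0 * d1 + i1) * d2 + i2) * d3 + i3, with an arbitrary 4-D shape:

    #include <cstdio>

    int main() {
      const int dims[4] = {2, 3, 4, 5};
      const int i0 = 1, i1 = 2, i2 = 3, i3 = 4;
      const int offset = ((i0 * dims[1] + i1) * dims[2] + i2) * dims[3] + i3;
      printf("%d\n", offset);  // prints 119
      return 0;
    }
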
 
@@ -361,6 +377,10 @@
   return Offset(dims, index[0], index[1], index[2], index[3]);
 }
 
+inline int Offset(const RuntimeShape& shape, int* index) {
+  return Offset(shape, index[0], index[1], index[2], index[3]);
+}
+
 // Get array size, DCHECKing that the dim index is in range.
 //
 // Note that this will be phased out with Dims<4>, since RuntimeShape::Dims()
@@ -410,7 +430,7 @@
   return flat_size;
 }
 
-// Deprecated. Prefer FlatSize.
+ABSL_DEPRECATED("Prefer FlatSize.")
 inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
   return FlatSize(dims);
 }
@@ -760,7 +780,10 @@
 struct DepthwiseParams {
   PaddingType padding_type;
   PaddingValues padding_values;
-  int16 stride;
+  int16 stride_width;
+  int16 stride_height;
+  int16 dilation_width_factor;
+  int16 dilation_height_factor;
   int16 depth_multiplier;
   // uint8 inference params.
   // TODO(b/65838351): Use smaller types if appropriate.
@@ -885,8 +908,8 @@
   // for LogSoftmax.
   double beta;
   // uint8 inference params.  Used even when beta defaults to 1.0.
-  int32 input_beta_multiplier;
-  int32 input_beta_left_shift;
+  int32 input_multiplier;
+  int32 input_left_shift;
   // Reverse scaling is only used by LogSoftmax.
   int32 reverse_scaling_divisor;
   int32 reverse_scaling_right_shift;
@@ -936,6 +959,11 @@
   int input_left_shift;
 };
 
+struct TransposeParams {
+  int8 perm_count;
+  int32 perm[4];
+};
+
 template <typename P>
 inline void SetActivationParams(float min, float max, P* params) {
   params->float_activation_min = min;
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index 1debf10..3e80638 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -119,6 +119,7 @@
 TfLiteRegistration* Register_UNPACK();
 TfLiteRegistration* Register_FLOOR_DIV();
 TfLiteRegistration* Register_SQUARE();
+TfLiteRegistration* Register_ZEROS_LIKE();
 
 TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
   context->ReportError(
@@ -157,7 +158,9 @@
   AddBuiltin(BuiltinOperator_MAX_POOL_2D, Register_MAX_POOL_2D());
   AddBuiltin(BuiltinOperator_L2_POOL_2D, Register_L2_POOL_2D());
   AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D());
-  AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D());
+  AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D(),
+             /* min_version */ 1,
+             /* max_version */ 2);
   AddBuiltin(BuiltinOperator_SVDF, Register_SVDF());
   AddBuiltin(BuiltinOperator_RNN, Register_RNN());
   AddBuiltin(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
@@ -245,6 +248,7 @@
   AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK());
   AddBuiltin(BuiltinOperator_FLOOR_DIV, Register_FLOOR_DIV());
   AddBuiltin(BuiltinOperator_SQUARE, Register_SQUARE());
+  AddBuiltin(BuiltinOperator_ZEROS_LIKE, Register_ZEROS_LIKE());
 
 #if 0
   // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
diff --git a/tensorflow/contrib/lite/kernels/zeros_like.cc b/tensorflow/contrib/lite/kernels/zeros_like.cc
new file mode 100644
index 0000000..cce5240
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/zeros_like.cc
@@ -0,0 +1,73 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace zeros_like {
+
+constexpr int kInputTensor = 0;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  output->type = input->type;
+
+  return context->ResizeTensor(context, output,
+                               TfLiteIntArrayCopy(input->dims));
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const int num_elements = NumElements(input);
+  switch (input->type) {
+    case kTfLiteInt64:
+      memset(GetTensorData<int64_t>(output), 0, num_elements * sizeof(int64_t));
+      break;
+    case kTfLiteInt32:
+      memset(GetTensorData<int32_t>(output), 0, num_elements * sizeof(int32_t));
+      break;
+    case kTfLiteFloat32:
+      memset(GetTensorData<float>(output), 0, num_elements * sizeof(float));
+      break;
+    default:
+      context->ReportError(context,
+                           "ZerosLike currently supports only int64, int32, "
+                           "and float32, got %d.",
+                           input->type);
+      return kTfLiteError;
+  }
+  return kTfLiteOk;
+}
+
+}  // namespace zeros_like
+
+TfLiteRegistration* Register_ZEROS_LIKE() {
+  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
+                                 zeros_like::Prepare, zeros_like::Eval};
+  return &r;
+}
+
+}  // namespace builtin
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/zeros_like_test.cc b/tensorflow/contrib/lite/kernels/zeros_like_test.cc
new file mode 100644
index 0000000..d3382d1
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/zeros_like_test.cc
@@ -0,0 +1,78 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class ZerosLikeOpModel : public SingleOpModel {
+ public:
+  explicit ZerosLikeOpModel(const TensorData& input) {
+    input_ = AddInput(input);
+    output_ = AddOutput(input);
+    SetBuiltinOp(BuiltinOperator_ZEROS_LIKE, BuiltinOptions_ZerosLikeOptions,
+                 CreateZerosLikeOptions(builder_).Union());
+    BuildInterpreter({GetShape(input_)});
+  }
+
+  int input() { return input_; }
+  int output() { return output_; }
+
+ protected:
+  int input_;
+  int output_;
+};
+
+TEST(ZerosLikeOpModel, ZerosLikeFloat) {
+  ZerosLikeOpModel m({TensorType_FLOAT32, {2, 3}});
+  m.PopulateTensor<float>(m.input(), {-2.0, -1.0, 0.0, 1.0, 2.0, 3.0});
+  m.Invoke();
+  EXPECT_THAT(m.ExtractVector<float>(m.output()),
+              ElementsAreArray({0.0, 0.0, 0.0, 0.0, 0.0, 0.0}));
+  EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({2, 3}));
+}
+
+TEST(ZerosLikeOpModel, ZerosLikeInt32) {
+  ZerosLikeOpModel m({TensorType_INT32, {1, 2, 2, 1}});
+  m.PopulateTensor<int32_t>(m.input(), {-2, -1, 0, 3});
+  m.Invoke();
+  EXPECT_THAT(m.ExtractVector<int32_t>(m.output()),
+              ElementsAreArray({0, 0, 0, 0}));
+  EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 2, 2, 1}));
+}
+
+TEST(ZerosLikeOpModel, ZerosLikeInt64) {
+  ZerosLikeOpModel m({TensorType_INT64, {1, 2, 2, 1}});
+  m.PopulateTensor<int64_t>(m.input(), {-2, -1, 0, 3});
+  m.Invoke();
+  EXPECT_THAT(m.ExtractVector<int64_t>(m.output()),
+              ElementsAreArray({0, 0, 0, 0}));
+  EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 2, 2, 1}));
+}
+
+}  // namespace
+}  // namespace tflite
+
+int main(int argc, char** argv) {
+  ::tflite::LogToStderr();
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/mutable_op_resolver.cc b/tensorflow/contrib/lite/mutable_op_resolver.cc
index 8ee63d2..a364043 100644
--- a/tensorflow/contrib/lite/mutable_op_resolver.cc
+++ b/tensorflow/contrib/lite/mutable_op_resolver.cc
@@ -30,10 +30,11 @@
 }
 
 void MutableOpResolver::AddBuiltin(tflite::BuiltinOperator op,
-                                   TfLiteRegistration* registration,
+                                   const TfLiteRegistration* registration,
                                    int min_version, int max_version) {
   for (int version = min_version; version <= max_version; ++version) {
     TfLiteRegistration new_registration = *registration;
+    new_registration.custom_name = nullptr;
     new_registration.builtin_code = op;
     new_registration.version = version;
     auto op_key = std::make_pair(op, version);
@@ -42,15 +43,27 @@
 }
 
 void MutableOpResolver::AddCustom(const char* name,
-                                  TfLiteRegistration* registration,
+                                  const TfLiteRegistration* registration,
                                   int min_version, int max_version) {
   for (int version = min_version; version <= max_version; ++version) {
     TfLiteRegistration new_registration = *registration;
     new_registration.builtin_code = BuiltinOperator_CUSTOM;
+    new_registration.custom_name = name;
     new_registration.version = version;
     auto op_key = std::make_pair(name, version);
     custom_ops_[op_key] = new_registration;
   }
 }
 
+void MutableOpResolver::AddAll(const MutableOpResolver& other) {
+  // map::insert does not replace existing elements, and map::insert_or_assign
+  // wasn't added until C++17.
+  for (const auto& other_builtin : other.builtins_) {
+    builtins_[other_builtin.first] = other_builtin.second;
+  }
+  for (const auto& other_custom_op : other.custom_ops_) {
+    custom_ops_[other_custom_op.first] = other_custom_op.second;
+  }
+}
+
 }  // namespace tflite
diff --git a/tensorflow/contrib/lite/mutable_op_resolver.h b/tensorflow/contrib/lite/mutable_op_resolver.h
index c319041..efd6cfa 100644
--- a/tensorflow/contrib/lite/mutable_op_resolver.h
+++ b/tensorflow/contrib/lite/mutable_op_resolver.h
@@ -57,10 +57,12 @@
   const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
                                    int version) const override;
   const TfLiteRegistration* FindOp(const char* op, int version) const override;
-  void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration,
-                  int min_version = 1, int max_version = 1);
-  void AddCustom(const char* name, TfLiteRegistration* registration,
+  void AddBuiltin(tflite::BuiltinOperator op,
+                  const TfLiteRegistration* registration, int min_version = 1,
+                  int max_version = 1);
+  void AddCustom(const char* name, const TfLiteRegistration* registration,
                  int min_version = 1, int max_version = 1);
+  void AddAll(const MutableOpResolver& other);
 
  private:
   typedef std::pair<tflite::BuiltinOperator, int> BuiltinOperatorKey;
diff --git a/tensorflow/contrib/lite/mutable_op_resolver_test.cc b/tensorflow/contrib/lite/mutable_op_resolver_test.cc
index db690ea..b70c703 100644
--- a/tensorflow/contrib/lite/mutable_op_resolver_test.cc
+++ b/tensorflow/contrib/lite/mutable_op_resolver_test.cc
@@ -36,6 +36,20 @@
   return &registration;
 }
 
+TfLiteStatus Dummy2Invoke(TfLiteContext* context, TfLiteNode* node) {
+  return kTfLiteOk;
+}
+
+TfLiteRegistration* GetDummy2Registration() {
+  static TfLiteRegistration registration = {
+      .init = nullptr,
+      .free = nullptr,
+      .prepare = nullptr,
+      .invoke = Dummy2Invoke,
+  };
+  return &registration;
+}
+
 TEST(MutableOpResolverTest, FindOp) {
   MutableOpResolver resolver;
   resolver.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration());
@@ -119,6 +133,26 @@
   EXPECT_EQ(found_registration, nullptr);
 }
 
+TEST(MutableOpResolverTest, AddAll) {
+  MutableOpResolver resolver1;
+  resolver1.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration());
+  resolver1.AddBuiltin(BuiltinOperator_MUL, GetDummy2Registration());
+
+  MutableOpResolver resolver2;
+  resolver2.AddBuiltin(BuiltinOperator_SUB, GetDummyRegistration());
+  resolver2.AddBuiltin(BuiltinOperator_ADD, GetDummy2Registration());
+
+  // resolver2's ADD op should replace resolver1's ADD op, while augmenting
+  // non-overlapping ops.
+  resolver1.AddAll(resolver2);
+  ASSERT_EQ(resolver1.FindOp(BuiltinOperator_ADD, 1)->invoke,
+            GetDummy2Registration()->invoke);
+  ASSERT_EQ(resolver1.FindOp(BuiltinOperator_MUL, 1)->invoke,
+            GetDummy2Registration()->invoke);
+  ASSERT_EQ(resolver1.FindOp(BuiltinOperator_SUB, 1)->invoke,
+            GetDummyRegistration()->invoke);
+}
+
 }  // namespace
 }  // namespace tflite
 
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index 0656884..b398e87 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -516,6 +516,10 @@
         nn_op_type = ANEURALNETWORKS_FULLY_CONNECTED;
         break;
       case tflite::BuiltinOperator_RESHAPE:
+        if (node.inputs->size != 2) {
+          logError("NNAPI only supports 2-input RESHAPE");
+          return kTfLiteError;
+        }
         nn_op_type = ANEURALNETWORKS_RESHAPE;
         // add_reshape_params(node.builtin_data);
         break;
@@ -677,6 +681,8 @@
       case tflite::BuiltinOperator_FLOOR_DIV:
       case tflite::BuiltinOperator_REDUCE_ANY:
       case tflite::BuiltinOperator_SQUARE:
+      case tflite::BuiltinOperator_ZEROS_LIKE:
+      case tflite::BuiltinOperator_FILL:
         logError("Op code %d is currently not delegated to NNAPI", builtin);
         return kTfLiteError;
         break;
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index f0db22d..3da3188 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -174,6 +174,8 @@
   FLOOR_DIV = 90,
   REDUCE_ANY = 91,
   SQUARE = 92,
+  ZEROS_LIKE = 93,
+  FILL = 94,
 }
 
 // Options for the builtin operators.
@@ -244,6 +246,8 @@
   UnpackOptions,
   FloorDivOptions,
   SquareOptions,
+  ZerosLikeOptions,
+  FillOptions,
 }
 
 enum Padding : byte { SAME, VALID }
@@ -588,6 +592,12 @@
 table SquareOptions {
 }
 
+table ZerosLikeOptions {
+}
+
+table FillOptions {
+}
+
 // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
 // builtin, or a string if the operator is custom.
 table OperatorCode {
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index 8c086a5..c7a59ca 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -229,6 +229,12 @@
 struct SquareOptions;
 struct SquareOptionsT;
 
+struct ZerosLikeOptions;
+struct ZerosLikeOptionsT;
+
+struct FillOptions;
+struct FillOptionsT;
+
 struct OperatorCode;
 struct OperatorCodeT;
 
@@ -387,11 +393,13 @@
   BuiltinOperator_FLOOR_DIV = 90,
   BuiltinOperator_REDUCE_ANY = 91,
   BuiltinOperator_SQUARE = 92,
+  BuiltinOperator_ZEROS_LIKE = 93,
+  BuiltinOperator_FILL = 94,
   BuiltinOperator_MIN = BuiltinOperator_ADD,
-  BuiltinOperator_MAX = BuiltinOperator_SQUARE
+  BuiltinOperator_MAX = BuiltinOperator_FILL
 };
 
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[92] {
+inline BuiltinOperator (&EnumValuesBuiltinOperator())[94] {
   static BuiltinOperator values[] = {
     BuiltinOperator_ADD,
     BuiltinOperator_AVERAGE_POOL_2D,
@@ -484,7 +492,9 @@
     BuiltinOperator_REDUCE_MIN,
     BuiltinOperator_FLOOR_DIV,
     BuiltinOperator_REDUCE_ANY,
-    BuiltinOperator_SQUARE
+    BuiltinOperator_SQUARE,
+    BuiltinOperator_ZEROS_LIKE,
+    BuiltinOperator_FILL
   };
   return values;
 }
@@ -584,6 +594,8 @@
     "FLOOR_DIV",
     "REDUCE_ANY",
     "SQUARE",
+    "ZEROS_LIKE",
+    "FILL",
     nullptr
   };
   return names;
@@ -662,11 +674,13 @@
   BuiltinOptions_UnpackOptions = 64,
   BuiltinOptions_FloorDivOptions = 65,
   BuiltinOptions_SquareOptions = 66,
+  BuiltinOptions_ZerosLikeOptions = 67,
+  BuiltinOptions_FillOptions = 68,
   BuiltinOptions_MIN = BuiltinOptions_NONE,
-  BuiltinOptions_MAX = BuiltinOptions_SquareOptions
+  BuiltinOptions_MAX = BuiltinOptions_FillOptions
 };
 
-inline BuiltinOptions (&EnumValuesBuiltinOptions())[67] {
+inline BuiltinOptions (&EnumValuesBuiltinOptions())[69] {
   static BuiltinOptions values[] = {
     BuiltinOptions_NONE,
     BuiltinOptions_Conv2DOptions,
@@ -734,7 +748,9 @@
     BuiltinOptions_LogicalNotOptions,
     BuiltinOptions_UnpackOptions,
     BuiltinOptions_FloorDivOptions,
-    BuiltinOptions_SquareOptions
+    BuiltinOptions_SquareOptions,
+    BuiltinOptions_ZerosLikeOptions,
+    BuiltinOptions_FillOptions
   };
   return values;
 }
@@ -808,6 +824,8 @@
     "UnpackOptions",
     "FloorDivOptions",
     "SquareOptions",
+    "ZerosLikeOptions",
+    "FillOptions",
     nullptr
   };
   return names;
@@ -1086,6 +1104,14 @@
   static const BuiltinOptions enum_value = BuiltinOptions_SquareOptions;
 };
 
+template<> struct BuiltinOptionsTraits<ZerosLikeOptions> {
+  static const BuiltinOptions enum_value = BuiltinOptions_ZerosLikeOptions;
+};
+
+template<> struct BuiltinOptionsTraits<FillOptions> {
+  static const BuiltinOptions enum_value = BuiltinOptions_FillOptions;
+};
+
 struct BuiltinOptionsUnion {
   BuiltinOptions type;
   void *value;
@@ -1645,6 +1671,22 @@
     return type == BuiltinOptions_SquareOptions ?
       reinterpret_cast<const SquareOptionsT *>(value) : nullptr;
   }
+  ZerosLikeOptionsT *AsZerosLikeOptions() {
+    return type == BuiltinOptions_ZerosLikeOptions ?
+      reinterpret_cast<ZerosLikeOptionsT *>(value) : nullptr;
+  }
+  const ZerosLikeOptionsT *AsZerosLikeOptions() const {
+    return type == BuiltinOptions_ZerosLikeOptions ?
+      reinterpret_cast<const ZerosLikeOptionsT *>(value) : nullptr;
+  }
+  FillOptionsT *AsFillOptions() {
+    return type == BuiltinOptions_FillOptions ?
+      reinterpret_cast<FillOptionsT *>(value) : nullptr;
+  }
+  const FillOptionsT *AsFillOptions() const {
+    return type == BuiltinOptions_FillOptions ?
+      reinterpret_cast<const FillOptionsT *>(value) : nullptr;
+  }
 };
 
 bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -5888,6 +5930,86 @@
 
 flatbuffers::Offset<SquareOptions> CreateSquareOptions(flatbuffers::FlatBufferBuilder &_fbb, const SquareOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
 
+struct ZerosLikeOptionsT : public flatbuffers::NativeTable {
+  typedef ZerosLikeOptions TableType;
+  ZerosLikeOptionsT() {
+  }
+};
+
+struct ZerosLikeOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+  typedef ZerosLikeOptionsT NativeTableType;
+  bool Verify(flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           verifier.EndTable();
+  }
+  ZerosLikeOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  void UnPackTo(ZerosLikeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  static flatbuffers::Offset<ZerosLikeOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ZerosLikeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct ZerosLikeOptionsBuilder {
+  flatbuffers::FlatBufferBuilder &fbb_;
+  flatbuffers::uoffset_t start_;
+  explicit ZerosLikeOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  ZerosLikeOptionsBuilder &operator=(const ZerosLikeOptionsBuilder &);
+  flatbuffers::Offset<ZerosLikeOptions> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = flatbuffers::Offset<ZerosLikeOptions>(end);
+    return o;
+  }
+};
+
+inline flatbuffers::Offset<ZerosLikeOptions> CreateZerosLikeOptions(
+    flatbuffers::FlatBufferBuilder &_fbb) {
+  ZerosLikeOptionsBuilder builder_(_fbb);
+  return builder_.Finish();
+}
+
+flatbuffers::Offset<ZerosLikeOptions> CreateZerosLikeOptions(flatbuffers::FlatBufferBuilder &_fbb, const ZerosLikeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct FillOptionsT : public flatbuffers::NativeTable {
+  typedef FillOptions TableType;
+  FillOptionsT() {
+  }
+};
+
+struct FillOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+  typedef FillOptionsT NativeTableType;
+  bool Verify(flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           verifier.EndTable();
+  }
+  FillOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  void UnPackTo(FillOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  static flatbuffers::Offset<FillOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const FillOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct FillOptionsBuilder {
+  flatbuffers::FlatBufferBuilder &fbb_;
+  flatbuffers::uoffset_t start_;
+  explicit FillOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  FillOptionsBuilder &operator=(const FillOptionsBuilder &);
+  flatbuffers::Offset<FillOptions> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = flatbuffers::Offset<FillOptions>(end);
+    return o;
+  }
+};
+
+inline flatbuffers::Offset<FillOptions> CreateFillOptions(
+    flatbuffers::FlatBufferBuilder &_fbb) {
+  FillOptionsBuilder builder_(_fbb);
+  return builder_.Finish();
+}
+
+flatbuffers::Offset<FillOptions> CreateFillOptions(flatbuffers::FlatBufferBuilder &_fbb, const FillOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
 struct OperatorCodeT : public flatbuffers::NativeTable {
   typedef OperatorCode TableType;
   BuiltinOperator builtin_code;
@@ -6219,6 +6341,12 @@
   const SquareOptions *builtin_options_as_SquareOptions() const {
     return builtin_options_type() == BuiltinOptions_SquareOptions ? static_cast<const SquareOptions *>(builtin_options()) : nullptr;
   }
+  const ZerosLikeOptions *builtin_options_as_ZerosLikeOptions() const {
+    return builtin_options_type() == BuiltinOptions_ZerosLikeOptions ? static_cast<const ZerosLikeOptions *>(builtin_options()) : nullptr;
+  }
+  const FillOptions *builtin_options_as_FillOptions() const {
+    return builtin_options_type() == BuiltinOptions_FillOptions ? static_cast<const FillOptions *>(builtin_options()) : nullptr;
+  }
   const flatbuffers::Vector<uint8_t> *custom_options() const {
     return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
   }
@@ -6514,6 +6642,14 @@
   return builtin_options_as_SquareOptions();
 }
 
+template<> inline const ZerosLikeOptions *Operator::builtin_options_as<ZerosLikeOptions>() const {
+  return builtin_options_as_ZerosLikeOptions();
+}
+
+template<> inline const FillOptions *Operator::builtin_options_as<FillOptions>() const {
+  return builtin_options_as_FillOptions();
+}
+
 struct OperatorBuilder {
   flatbuffers::FlatBufferBuilder &fbb_;
   flatbuffers::uoffset_t start_;
@@ -8782,6 +8918,52 @@
       _fbb);
 }
 
+inline ZerosLikeOptionsT *ZerosLikeOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = new ZerosLikeOptionsT();
+  UnPackTo(_o, _resolver);
+  return _o;
+}
+
+inline void ZerosLikeOptions::UnPackTo(ZerosLikeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+}
+
+inline flatbuffers::Offset<ZerosLikeOptions> ZerosLikeOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ZerosLikeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateZerosLikeOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<ZerosLikeOptions> CreateZerosLikeOptions(flatbuffers::FlatBufferBuilder &_fbb, const ZerosLikeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ZerosLikeOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  return tflite::CreateZerosLikeOptions(
+      _fbb);
+}
+
+inline FillOptionsT *FillOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = new FillOptionsT();
+  UnPackTo(_o, _resolver);
+  return _o;
+}
+
+inline void FillOptions::UnPackTo(FillOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+}
+
+inline flatbuffers::Offset<FillOptions> FillOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const FillOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateFillOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<FillOptions> CreateFillOptions(flatbuffers::FlatBufferBuilder &_fbb, const FillOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const FillOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  return tflite::CreateFillOptions(
+      _fbb);
+}
+
 inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
   auto _o = new OperatorCodeT();
   UnPackTo(_o, _resolver);
@@ -9235,6 +9417,14 @@
       auto ptr = reinterpret_cast<const SquareOptions *>(obj);
       return verifier.VerifyTable(ptr);
     }
+    case BuiltinOptions_ZerosLikeOptions: {
+      auto ptr = reinterpret_cast<const ZerosLikeOptions *>(obj);
+      return verifier.VerifyTable(ptr);
+    }
+    case BuiltinOptions_FillOptions: {
+      auto ptr = reinterpret_cast<const FillOptions *>(obj);
+      return verifier.VerifyTable(ptr);
+    }
     default: return false;
   }
 }
@@ -9517,6 +9707,14 @@
       auto ptr = reinterpret_cast<const SquareOptions *>(obj);
       return ptr->UnPack(resolver);
     }
+    case BuiltinOptions_ZerosLikeOptions: {
+      auto ptr = reinterpret_cast<const ZerosLikeOptions *>(obj);
+      return ptr->UnPack(resolver);
+    }
+    case BuiltinOptions_FillOptions: {
+      auto ptr = reinterpret_cast<const FillOptions *>(obj);
+      return ptr->UnPack(resolver);
+    }
     default: return nullptr;
   }
 }
@@ -9787,6 +9985,14 @@
       auto ptr = reinterpret_cast<const SquareOptionsT *>(value);
       return CreateSquareOptions(_fbb, ptr, _rehasher).Union();
     }
+    case BuiltinOptions_ZerosLikeOptions: {
+      auto ptr = reinterpret_cast<const ZerosLikeOptionsT *>(value);
+      return CreateZerosLikeOptions(_fbb, ptr, _rehasher).Union();
+    }
+    case BuiltinOptions_FillOptions: {
+      auto ptr = reinterpret_cast<const FillOptionsT *>(value);
+      return CreateFillOptions(_fbb, ptr, _rehasher).Union();
+    }
     default: return 0;
   }
 }
@@ -10057,6 +10263,14 @@
       value = new SquareOptionsT(*reinterpret_cast<SquareOptionsT *>(u.value));
       break;
     }
+    case BuiltinOptions_ZerosLikeOptions: {
+      value = new ZerosLikeOptionsT(*reinterpret_cast<ZerosLikeOptionsT *>(u.value));
+      break;
+    }
+    case BuiltinOptions_FillOptions: {
+      value = new FillOptionsT(*reinterpret_cast<FillOptionsT *>(u.value));
+      break;
+    }
     default:
       break;
   }
@@ -10394,6 +10608,16 @@
       delete ptr;
       break;
     }
+    case BuiltinOptions_ZerosLikeOptions: {
+      auto ptr = reinterpret_cast<ZerosLikeOptionsT *>(value);
+      delete ptr;
+      break;
+    }
+    case BuiltinOptions_FillOptions: {
+      auto ptr = reinterpret_cast<FillOptionsT *>(value);
+      delete ptr;
+      break;
+    }
     default: break;
   }
   value = nullptr;
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 3754b58..014c80b 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -2834,6 +2834,31 @@
   make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
 
 
+def make_zeros_like_tests(zip_path):
+  """Make a set of tests to do zeros_like."""
+
+  test_parameters = [{
+      "input_dtype": [tf.float32, tf.int32, tf.int64],
+      "input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
+  }]
+
+  def build_graph(parameters):
+    """Build the zeros_like op testing graph."""
+    input_tensor = tf.placeholder(
+        dtype=parameters["input_dtype"],
+        name="input",
+        shape=parameters["input_shape"])
+    out = tf.zeros_like(input_tensor)
+    return [input_tensor], [out]
+
+  def build_inputs(parameters, sess, inputs, outputs):
+    values = create_tensor_data(parameters["input_dtype"],
+                                parameters["input_shape"])
+    return [values], sess.run(outputs, feed_dict=dict(zip(inputs, [values])))
+
+  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
 def _make_elementwise_tests(op):
   """Make a set of tests to do element-wise operations."""
 
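For reference, a minimal standalone version of the graph each generated zeros_like test case exercises (a sketch using the same TF 1.x APIs as the test generator above; the zip-test harness pieces are elided):

import numpy as np
import tensorflow as tf

with tf.Graph().as_default():
  # Same structure as build_graph() in make_zeros_like_tests above.
  input_tensor = tf.placeholder(dtype=tf.float32, shape=[1, 2], name="input")
  out = tf.zeros_like(input_tensor)
  with tf.Session() as sess:
    values = np.random.uniform(size=[1, 2]).astype(np.float32)
    result = sess.run(out, feed_dict={input_tensor: values})
  # zeros_like only reads the input's shape and dtype; the result is all zeros.
  assert result.shape == (1, 2) and not result.any()
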
diff --git a/tensorflow/contrib/lite/tflite_static.bp b/tensorflow/contrib/lite/tflite_static.bp
index 6036413..45b9237 100644
--- a/tensorflow/contrib/lite/tflite_static.bp
+++ b/tensorflow/contrib/lite/tflite_static.bp
@@ -100,6 +100,7 @@
         "kernels/unidirectional_sequence_lstm.cc",
         "kernels/unidirectional_sequence_rnn.cc",
         "kernels/unpack.cc",
+        "kernels/zeros_like.cc",
         "kernels/internal/kernel_utils.cc",
         "kernels/internal/tensor_utils.cc",
         "kernels/internal/quantization_util.cc",
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index bea90f1..96b88b6 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -347,6 +347,7 @@
         "//tensorflow/core:framework",
         "//tensorflow/core:graph",
         "//tensorflow/core:lib",
+        "//tensorflow/core:ops",
         "//tensorflow/core:protos_all_cc",
         "@com_google_googletest//:gtest_main",
     ],
@@ -407,8 +408,11 @@
         ":toco_port",
         ":toco_tooling",
         ":types_proto_cc",
-        "//tensorflow/core:lib",
         "@com_google_absl//absl/strings",
+        "//tensorflow/core:lib",
+        # We cannot embed the core:ops dependency directly into :toco_tooling as
+        # it can conflict with downstream deps when toco is used as a library.
+        "//tensorflow/core:ops",
     ],
 )
 
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index b52a792..61e9106 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -470,6 +470,17 @@
   strides.mutable_list()->add_i(src_op.stride_height);
   strides.mutable_list()->add_i(src_op.stride_width);
   strides.mutable_list()->add_i(1);
+  // TODO(b/116063589): To return a working TF GraphDef, we should be returning
+  // the correct SpaceToBatchND and BatchToSpaceND operations before and after
+  // the conv since TF doesn't support dilations.
+  if ((src_op.dilation_width_factor != 1) ||
+      (src_op.dilation_height_factor != 1)) {
+    auto& dilations = (*dc2d_op->mutable_attr())["dilations"];
+    dilations.mutable_list()->add_i(1);
+    dilations.mutable_list()->add_i(src_op.dilation_height_factor);
+    dilations.mutable_list()->add_i(src_op.dilation_width_factor);
+    dilations.mutable_list()->add_i(1);
+  }
   string padding;
   if (src_op.padding.type == PaddingType::kSame) {
     padding = "SAME";
@@ -1968,6 +1979,19 @@
   (*unpack_op->mutable_attr())["axis"].set_i(src_op.axis);
 }
 
+void ConvertZerosLikeOperator(const Model& model,
+                              const TensorFlowZerosLikeOperator& src_op,
+                              const char* op_name, GraphDef* tensorflow_graph) {
+  tensorflow::NodeDef* zeros_like_op = tensorflow_graph->add_node();
+  zeros_like_op->set_op(op_name);
+  zeros_like_op->set_name(src_op.outputs[0]);
+  DCHECK_EQ(src_op.inputs.size(), 1);
+  *zeros_like_op->add_input() = src_op.inputs[0];
+  const tensorflow::DataType data_type =
+      GetTensorFlowDataType(model, src_op.inputs[0]);
+  (*zeros_like_op->mutable_attr())["T"].set_type(data_type);
+}
+
 void ConvertOperator(const Model& model, const Operator& src_op,
                      GraphDef* tensorflow_graph) {
   if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) {
@@ -2233,6 +2257,10 @@
   } else if (src_op.type == OperatorType::kUnpack) {
     ConvertUnpackOperator(model, static_cast<const UnpackOperator&>(src_op),
                           "Unpack", tensorflow_graph);
+  } else if (src_op.type == OperatorType::kZerosLike) {
+    ConvertZerosLikeOperator(
+        model, static_cast<const TensorFlowZerosLikeOperator&>(src_op),
+        "ZerosLike", tensorflow_graph);
   } else {
     LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
   }
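
ConvertZerosLikeOperator above re-emits a single ZerosLike node whose "T" attribute carries the element type. A sketch of the matching node as TensorFlow itself builds it (TF 1.x graph mode; a partially known shape is used so tf.zeros_like is not constant-folded away):

import tensorflow as tf

with tf.Graph().as_default() as g:
  x = tf.placeholder(tf.float32, shape=[None, 3], name="input")
  z = tf.zeros_like(x)
  node = [n for n in g.as_graph_def().node if n.op == "ZerosLike"][0]
  # One input plus a "T" type attribute -- the same fields the exporter writes.
  assert list(node.input) == ["input"]
  assert node.attr["T"].type == tf.float32.as_datatype_enum
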
diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md
index 84680b9..aba7536 100644
--- a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md
+++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md
@@ -38,7 +38,7 @@
     examples below use `tflite_convert` for simplicity.
     *   Example: `tflite_convert --output_file=...`
 *   `bazel`: In order to run the latest version of TOCO, [clone the TensorFlow
-    repository](https://www.tensorflow.org/install/install_sources#clone_the_tensorflow_repository)
+    repository](https://www.tensorflow.org/install/source)
     and use `bazel`. This is the recommended approach for converting models that
     utilize new features that were not supported by TOCO in TensorFlow 1.9.
     *   Example: `bazel run
diff --git a/tensorflow/contrib/lite/toco/g3doc/python_api.md b/tensorflow/contrib/lite/toco/g3doc/python_api.md
index 51f808d..910fa4c 100644
--- a/tensorflow/contrib/lite/toco/g3doc/python_api.md
+++ b/tensorflow/contrib/lite/toco/g3doc/python_api.md
@@ -260,7 +260,7 @@
 In order to run the latest version of the TOCO Python API, clone the TensorFlow
 repository, configure the installation, and build and install the pip package.
 Detailed instructions are available
-[here](https://www.tensorflow.org/install/install_sources).
+[here](https://www.tensorflow.org/install/source).
 
 ### Converting models prior to TensorFlow 1.9. <a name="pre-tensorflow-1.9"></a>
 
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index fdd0632..4d213b3 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -133,7 +133,6 @@
 DECLARE_GRAPH_TRANSFORMATION(MergeReshapeIntoPrecedingTranspose)
 DECLARE_GRAPH_TRANSFORMATION(IdentifyRelu1)
 DECLARE_GRAPH_TRANSFORMATION(IdentifyPRelu)
-DECLARE_GRAPH_TRANSFORMATION(IdentifyDilatedConv)
 DECLARE_GRAPH_TRANSFORMATION(MakeInitialDequantizeOperator)
 DECLARE_GRAPH_TRANSFORMATION(MoveBinaryOperatorBeforeReshape)
 DECLARE_GRAPH_TRANSFORMATION(PropagateActivationFunctionIntoConstants)
@@ -266,6 +265,17 @@
   bool has_default_ranges_flag_ = false;
 };
 
+class IdentifyDilatedConv : public GraphTransformation {
+ public:
+  bool Run(Model* model, std::size_t op_index) override;
+  const char* Name() const override { return "IdentifyDilatedConv"; }
+  bool identify_depthwise_conv() const { return identify_depthwise_conv_; }
+  void set_identify_depthwise_conv(bool val) { identify_depthwise_conv_ = val; }
+
+ private:
+  bool identify_depthwise_conv_ = true;
+};
+
 #undef DECLARE_GRAPH_TRANSFORMATION
 
 }  // end namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc
index d49857c..aac77eb 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc
@@ -53,50 +53,11 @@
 // thrown in just for the extra headache. Padding adapts non-conforming input
 // sizes, and can be discarded. The bias is necessary, so is kept.
 
-bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
-  const auto it = model->operators.begin() + op_index;
-  auto* stb_op = it->get();
-
-  // 1. IDENTIFY OPERATORS
-  // ***************************************************************************
-  // SpaceToBatch Op.
-  if (stb_op->type != OperatorType::kSpaceToBatchND) {
-    return false;
-  }
-  if (stb_op->inputs.size() != 3) {
-    return false;
-  }
-  CHECK_EQ(stb_op->outputs.size(), 1);
-  // Extract the dilation factor from Input[1] of SpaceToBatch
-  // TODO(mjmatthews): Support 2D dilation factors.
-  const auto& block_shape_array = model->GetArray(stb_op->inputs[1]);
-  if (!block_shape_array.buffer) {
-    return false;
-  }
-  CHECK_EQ(block_shape_array.shape().dimensions_count(), 1);
-  int dilation_factor =
-      block_shape_array.Array::GetBuffer<ArrayDataType::kInt32>().data[0];
-
-  // Expand Op
-  auto* post_stb_op = GetOpWithInput(*model, stb_op->outputs[0]);
-  if (!post_stb_op) {
-    return false;
-  }
-  bool has_expand_op = false;
-  if (post_stb_op->type == OperatorType::kExpandDims) {
-    has_expand_op = true;
-    CHECK_EQ(post_stb_op->inputs.size(), 2);
-    CHECK_EQ(post_stb_op->outputs.size(), 1);
-  }
-
-  // Conv Op
-  const string& input_of_conv_op =
-      has_expand_op ? post_stb_op->outputs[0] : stb_op->outputs[0];
-  auto* conv_base_op = GetOpWithInput(*model, input_of_conv_op);
-  if (conv_base_op->type != OperatorType::kConv) {
-    return false;
-  }
-  auto* conv_op = static_cast<ConvOperator*>(conv_base_op);
+template <typename T>
+bool ResolveDilatedConv(Model* model, Operator* conv_base_op, Operator* stb_op,
+                        Operator* post_stb_op, bool has_expand_op,
+                        int dilation_factor) {
+  auto* conv_op = static_cast<T*>(conv_base_op);
   if (conv_op->inputs.size() != 2) {
     // The conv op must only have weights, no bias.
     return false;
@@ -158,8 +119,6 @@
   CHECK_EQ(bias_add_op->inputs.size(), 2);
   CHECK_EQ(bias_add_op->outputs.size(), 1);
 
-  LOG(INFO) << "Identified sub-network emulating dilated convolution.";
-
   // 2. RE-WIRE OPERATORS
   // ***************************************************************************
   // Re-use the existing Conv2D op.
@@ -206,9 +165,71 @@
   DeleteArrayIfUnused(stb_op_inputs[1], model);
   DeleteArrayIfUnused(stb_op_inputs[2], model);
 
-  LOG(INFO) << "Replaced with Dilated Conv2D op outputting \""
-            << conv_op->outputs[0] << "\".";
   return true;
 }
 
+bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
+  const auto it = model->operators.begin() + op_index;
+  auto* stb_op = it->get();
+
+  // 1. IDENTIFY OPERATORS
+  // ***************************************************************************
+  // SpaceToBatch Op.
+  if (stb_op->type != OperatorType::kSpaceToBatchND) {
+    return false;
+  }
+  if (stb_op->inputs.size() != 3) {
+    return false;
+  }
+  CHECK_EQ(stb_op->outputs.size(), 1);
+  // Extract the dilation factor from Input[1] of SpaceToBatch
+  // TODO(mjmatthews): Support 2D dilation factors.
+  const auto& block_shape_array = model->GetArray(stb_op->inputs[1]);
+  if (!block_shape_array.buffer) {
+    return false;
+  }
+  CHECK_EQ(block_shape_array.shape().dimensions_count(), 1);
+  int dilation_factor =
+      block_shape_array.Array::GetBuffer<ArrayDataType::kInt32>().data[0];
+
+  // Expand Op
+  auto* post_stb_op = GetOpWithInput(*model, stb_op->outputs[0]);
+  if (!post_stb_op) {
+    return false;
+  }
+  bool has_expand_op = false;
+  if (post_stb_op->type == OperatorType::kExpandDims) {
+    has_expand_op = true;
+    CHECK_EQ(post_stb_op->inputs.size(), 2);
+    CHECK_EQ(post_stb_op->outputs.size(), 1);
+  }
+
+  // Conv Op
+  const string& input_of_conv_op =
+      has_expand_op ? post_stb_op->outputs[0] : stb_op->outputs[0];
+  auto* conv_base_op = GetOpWithInput(*model, input_of_conv_op);
+  bool changed = false;
+  if (conv_base_op->type == OperatorType::kConv) {
+    changed = ResolveDilatedConv<ConvOperator>(model, conv_base_op, stb_op,
+                                               post_stb_op, has_expand_op,
+                                               dilation_factor);
+    if (changed) {
+      LOG(INFO) << "Replaced sub-network with Dilated Conv2D op outputting \""
+                << conv_base_op->outputs[0] << "\".";
+    }
+  } else if (identify_depthwise_conv_ &&
+             conv_base_op->type == OperatorType::kDepthwiseConv) {
+    changed = ResolveDilatedConv<DepthwiseConvOperator>(
+        model, conv_base_op, stb_op, post_stb_op, has_expand_op,
+        dilation_factor);
+    if (changed) {
+      LOG(INFO)
+          << "Replaced sub-netork with Dilated DepthwiseConv2D op outputting \""
+          << conv_base_op->outputs[0] << "\".";
+    }
+  }
+
+  return changed;
+}
+
 }  // namespace toco
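
The sub-network this transformation matches is what TF 1.x graph construction emits for a dilated convolution: the dilation is lowered to SpaceToBatchND, a stride-1 conv, and BatchToSpaceND, with optional Pad, ExpandDims and BiasAdd around it. A sketch, assuming the tf.nn.convolution lowering path:

import tensorflow as tf

with tf.Graph().as_default() as g:
  x = tf.placeholder(tf.float32, shape=[1, 32, 32, 3])
  w = tf.constant(0.1, shape=[3, 3, 3, 8])
  y = tf.nn.convolution(x, w, padding="SAME", dilation_rate=[2, 2])
  ops_in_graph = {n.op for n in g.as_graph_def().node}
  # These batch-space reshuffling ops are what IdentifyDilatedConv collapses
  # back into a single Conv2D/DepthwiseConv2D with dilation factors set.
  assert {"SpaceToBatchND", "Conv2D", "BatchToSpaceND"} <= ops_in_graph
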
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index f103bb9..f943da6 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -285,7 +285,8 @@
   const int kheight = weights_shape.dims(1);
   const int kwidth = weights_shape.dims(2);
   ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width,
-                   op->stride_height, 1, 1, op->padding.type,
+                   op->stride_height, op->dilation_width_factor,
+                   op->dilation_height_factor, op->padding.type,
                    model->GetArray(output_name).mutable_shape(),
                    &op->padding.GetOrCreateFixedPadding());
 }
@@ -1655,6 +1656,7 @@
     case OperatorType::kLogicalAnd:
     case OperatorType::kLogicalNot:
     case OperatorType::kLogicalOr:
+    case OperatorType::kZerosLike:
       ProcessSimpleOperator(model, op, 0);
       break;
     case OperatorType::kGather:
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc
index 8266e2c..8e150db 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc
@@ -25,29 +25,57 @@
 
 namespace toco {
 
+namespace {
+
+void RenameArray(Model* model, const string& oldname,
+                 const string& desired_newname) {
+  const string& newname = AvailableArrayName(*model, desired_newname);
+  auto& arrays = model->GetMutableArrayMap();
+  arrays[newname] = std::move(arrays[oldname]);
+  arrays.erase(oldname);
+  for (const auto& op : model->operators) {
+    for (string& input : op->inputs) {
+      if (input == oldname) {
+        input = newname;
+      }
+    }
+    for (string& output : op->outputs) {
+      if (output == oldname) {
+        output = newname;
+      }
+    }
+  }
+}
+
+}  // namespace
+
 // Reorder the elements of an input_array according to the input_axes_order and
 // output_axes_order. Then adjust the shapes of the input and output arrays
 // accordingly. Note that input_array must have a buffer (that is, it is a
 // constant array).
 template <typename T, ArrayDataType DataType>
 void ReorderAxes(AxesOrder input_axes_order, AxesOrder output_axes_order,
-                 Array* input_array, Array* output_array) {
-  CHECK(input_array->buffer->type == DataType);
-  CHECK(!output_array->buffer);
-  auto& input_data = input_array->GetMutableBuffer<DataType>().data;
-  std::vector<T> reordered_data;
-  reordered_data.resize(RequiredBufferSizeForShape(output_array->shape()));
+                 const Array& input_array, Array* output_array) {
+  DCHECK(input_array.buffer->type == DataType);
+  DCHECK(!output_array->buffer);
+  const auto& input_data = input_array.GetBuffer<DataType>().data;
+  auto& output_data = output_array->GetMutableBuffer<DataType>().data;
+  output_data.resize(RequiredBufferSizeForShape(output_array->shape()));
   // TODO(b/62904716) Shapes should be used directly.
-  Shape input_shape = input_array->shape();
+  Shape input_shape = input_array.shape();
   Shape output_shape = output_array->shape();
   if (AxesCount(input_axes_order) == 2) {
     UnextendShape(&input_shape, 2);
     UnextendShape(&output_shape, 2);
   }
   ShuffleArray(input_shape, input_axes_order, output_axes_order, output_shape,
-               input_data.data(), reordered_data.data());
-  input_data = reordered_data;
-  input_array->copy_shape(output_array->shape());
+               input_data.data(), output_data.data());
+  if (input_array.minmax) {
+    output_array->GetOrCreateMinMax() = input_array.GetMinMax();
+  }
+  if (input_array.narrow_range) {
+    output_array->narrow_range = true;
+  }
 }
 
 bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) {
@@ -57,8 +85,11 @@
     return false;
   }
   auto* reorder_op = static_cast<ReorderAxesOperator*>(op);
-  const auto& input_array_name = reorder_op->inputs[0];
-  const auto& output_array_name = reorder_op->outputs[0];
+
+  // Intentionally copies, not references.
+  const string input_array_name = reorder_op->inputs[0];
+  const string output_array_name = reorder_op->outputs[0];
+
   auto& input_array = model->GetArray(input_array_name);
   auto& output_array = model->GetArray(output_array_name);
   if (!input_array.buffer) {
@@ -72,31 +103,23 @@
   if (input_array.buffer->type == ArrayDataType::kFloat) {
     ReorderAxes<float, ArrayDataType::kFloat>(reorder_op->input_axes_order,
                                               reorder_op->output_axes_order,
-                                              &input_array, &output_array);
-  } else if (input_array.buffer->type == ArrayDataType::kInt32) {
+                                              input_array, &output_array);
+  } else if (input_array.buffer->type == ArrayDataType::kUint8) {
+    // TODO(benoitjacob): This path seems unused.
+    // ReorderAxes is only used when importing from
+    // TensorFlow GraphDef, which does not support quantized nodes.
     ReorderAxes<uint8, ArrayDataType::kUint8>(reorder_op->input_axes_order,
                                               reorder_op->output_axes_order,
-                                              &input_array, &output_array);
+                                              input_array, &output_array);
   } else {
     LOG(FATAL) << "Cannot ReorderAxes unless input buffer is float or uint8.";
   }
 
-  input_array.copy_shape(output_array.shape());
-
-  // Update the edges of the graph to point to the input array
-  for (const auto& other_op : model->operators) {
-    for (auto& input : other_op->inputs) {
-      if (input == output_array_name) {
-        input = input_array_name;
-      }
-    }
-  }
-
   AddMessageF("Reordered axes for array %s", input_array_name);
 
-  // Remove the op and output array.
-  model->EraseArray(output_array_name);
-  model->operators.erase(it);
+  DeleteOpAndArraysIfUnused(model, op);
+  RenameArray(model, output_array_name, input_array_name);
+
   return true;
 }
 
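ReorderAxes itself reshuffles a constant buffer between axes orders, for example between TensorFlow's HWIO conv filter layout and TFLite's OHWI layout; the rewrite above only changes how the result array is named and how min/max metadata is carried over. A small numpy sketch of the reshuffle, assuming those two layouts:

import numpy as np

# A TensorFlow conv filter: [height, width, input_depth, output_depth] (HWIO).
w_hwio = np.random.rand(3, 3, 16, 32).astype(np.float32)
# The corresponding TFLite layout: [output_depth, height, width, input_depth] (OHWI).
w_ohwi = np.transpose(w_hwio, [3, 0, 1, 2])
assert w_ohwi.shape == (32, 3, 3, 16)
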
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
index fcf30bd..65346c4 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
@@ -24,6 +24,37 @@
 
 namespace toco {
 
+namespace {
+
+TransposeOperator* FindTransposeOpWithInput(const Model& model,
+                                            const string& array_name) {
+  for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
+    Operator* op = it->get();
+    if (op->type != OperatorType::kTranspose) {
+      continue;
+    }
+    if (op->inputs[0] != array_name) {
+      continue;
+    }
+    const auto& permutation_array = model.GetArray(op->inputs[1]);
+    if (permutation_array.data_type != ArrayDataType::kInt32) {
+      continue;
+    }
+    const auto& permutation_data =
+        permutation_array.GetBuffer<ArrayDataType::kInt32>().data;
+    if (permutation_data.size() != 2) {
+      continue;
+    }
+    if (permutation_data[0] != 1 || permutation_data[1] != 0) {
+      continue;
+    }
+    return static_cast<TransposeOperator*>(op);
+  }
+  return nullptr;
+}
+
+}  // namespace
+
 bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
   auto matmul_it = model->operators.begin() + op_index;
   if (matmul_it->get()->type != OperatorType::kMatMul) {
@@ -37,7 +68,13 @@
   // TransposeOperator.  However, the second input is supposed to be 2D, so we
   // can actually handle transposition of that matrix, which happens to be more
   // common anyway.
-  CHECK(!matmul_op->transpose_a);
+  if (matmul_op->transpose_a) {
+    AddMessageF(
+        "Not replacing %s by a FullyConnected operator, because it has "
+        "the transpose_a attribute",
+        LogName(*matmul_op));
+    return false;
+  }
 
   // Reorder the axes on the second input. TensorFlow uses row-major ordering
   // on both inputs, however this is inefficient for the FullyConnected
@@ -46,18 +83,35 @@
   string input_lhs = matmul_op->inputs[0];
   string input_rhs = matmul_op->inputs[1];
   if (!matmul_op->transpose_b) {
-    auto* transpose_op = new TransposeOperator;
-    transpose_op->inputs = {
-        matmul_op->inputs[1],
-        CreateInt32Array(model,
-                         AvailableArrayName(
-                             *model, matmul_op->inputs[1] + "/transpose/perm"),
-                         {1, 0})};
-    transpose_op->outputs = {
-        AvailableArrayName(*model, matmul_op->inputs[1] + "/transpose")};
-    model->GetOrCreateArray(transpose_op->outputs[0]);
-    model->operators.emplace(matmul_it, transpose_op);
-
+    // Need to transpose input_rhs, by inserting a TransposeOperator.
+    // First, check if there already is a TransposeOperator transposing that
+    // array, so we can just reuse it.
+    auto* transpose_op = FindTransposeOpWithInput(*model, input_rhs);
+    if (!transpose_op) {
+      AddMessageF(
+          "While replacing %s by a FullyConnected operator, created new "
+          "Transpose op wrapping RHS input array %s",
+          LogName(*matmul_op), input_rhs);
+      // No such TransposeOperator found. Create one now.
+      transpose_op = new TransposeOperator;
+      transpose_op->inputs = {
+          input_rhs,
+          CreateInt32Array(
+              model, AvailableArrayName(*model, input_rhs + "/transpose/perm"),
+              {1, 0})};
+      transpose_op->outputs = {
+          AvailableArrayName(*model, input_rhs + "/transpose")};
+      model->GetOrCreateArray(transpose_op->outputs[0]);
+      model->operators.emplace(matmul_it, transpose_op);
+      // Sanity check
+      DCHECK_EQ(transpose_op, FindTransposeOpWithInput(*model, input_rhs));
+    } else {
+      AddMessageF(
+          "While replacing %s by a FullyConnected operator, reused existing "
+          "Transpose op wrapping RHS input array %s",
+          LogName(*matmul_op), input_rhs);
+    }
+    // Re-wire: have the matmul consume the transposed array.
     input_rhs = transpose_op->outputs[0];
   }
 
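The Transpose inserted (or reused) above exists because TFLite's FullyConnected stores its weights as [output_depth, input_depth], i.e. the transpose of the MatMul's right-hand side. A numpy sketch of the equivalence, with illustrative shapes:

import numpy as np

a = np.random.rand(4, 3).astype(np.float32)  # activations: [batch, input_depth]
b = np.random.rand(3, 5).astype(np.float32)  # MatMul RHS:  [input_depth, output_depth]
w = b.T                                      # FullyConnected weights: [output_depth, input_depth]
# FullyConnected computes a @ w.T, which reproduces the original MatMul.
assert np.allclose(a @ b, a @ w.T)
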
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 9bc23c4..e02d000 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -58,6 +58,7 @@
 using tensorflow::DT_UINT8;
 using tensorflow::GraphDef;
 using tensorflow::NodeDef;
+using tensorflow::OpRegistry;
 using tensorflow::TensorProto;
 using tensorflow::TensorShapeProto;
 
@@ -68,6 +69,13 @@
   return node.attr().count(attr_name) > 0;
 }
 
+bool HasWildcardDimension(const TensorShapeProto& shape) {
+  for (const auto& dim : shape.dim()) {
+    if (dim.size() == -1) return true;
+  }
+  return false;
+}
+
 const string& GetStringAttr(const NodeDef& node, const string& attr_name) {
   CHECK(HasAttr(node, attr_name));
   const auto& attr = node.attr().at(attr_name);
@@ -633,6 +641,23 @@
   CHECK_EQ(strides.i(3), 1);
   conv->stride_height = strides.i(1);
   conv->stride_width = strides.i(2);
+  if (HasAttr(node, "dilations")) {
+    const auto& dilations = GetListAttr(node, "dilations");
+    TF_RETURN_IF_ERROR(
+        ExpectValue(dilations.i_size(), 4, "number of dilations"));
+    if (dilations.i(0) != 1 || dilations.i(3) != 1) {
+      return tensorflow::errors::InvalidArgument(absl::StrCat(
+          "Can only import Conv ops with dilation along the height "
+          "(1st) or width (2nd) axis. TensorFlow op \"",
+          node.name(), "\" had dilations:[ ", dilations.i(0), ", ",
+          dilations.i(1), ", ", dilations.i(2), ", ", dilations.i(3), "]."));
+    }
+    conv->dilation_height_factor = dilations.i(1);
+    conv->dilation_width_factor = dilations.i(2);
+  } else {
+    conv->dilation_height_factor = 1;
+    conv->dilation_width_factor = 1;
+  }
   const auto& padding = GetStringAttr(node, "padding");
   if (padding == "SAME") {
     conv->padding.type = PaddingType::kSame;
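
The "dilations" attribute parsed here is a 4-element, NHWC-ordered list [1, dilation_height, dilation_width, 1]; anything other than 1 in the batch or depth slots is rejected. A TF 1.x sketch of a Conv2D node carrying that attribute:

import tensorflow as tf

with tf.Graph().as_default() as g:
  x = tf.placeholder(tf.float32, shape=[1, 32, 32, 3])
  w = tf.constant(0.1, shape=[3, 3, 3, 8])
  y = tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding="SAME",
                   dilations=[1, 2, 2, 1])
  node = [n for n in g.as_graph_def().node if n.op == "Conv2D"][0]
  # Entries 1 and 2 become dilation_height_factor / dilation_width_factor above.
  print(list(node.attr["dilations"].list.i))  # [1, 2, 2, 1]
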
@@ -1053,15 +1078,27 @@
       "_support_output_type_float_in_quantized_op";
 
   LOG(INFO) << "Converting unsupported operation: " << node.op();
+
   auto* op = new TensorFlowUnsupportedOperator;
+  op->tensorflow_op = node.op();
+  node.SerializeToString(&op->tensorflow_node_def);
+  model->operators.emplace_back(op);
+
+  // Parse inputs.
   const int num_inputs = GetInputsCount(node, tf_import_flags);
   for (int i = 0; i < num_inputs; ++i) {
     op->inputs.push_back(node.input(i));
   }
-  op->outputs.push_back(node.name());
-  op->tensorflow_op = node.op();
-  node.SerializeToString(&op->tensorflow_node_def);
-  model->operators.emplace_back(op);
+
+  // Parse outputs.
+  op->outputs.push_back(node.name());  // Implicit :0.
+  const tensorflow::OpDef* op_def = nullptr;
+  if (tensorflow::OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
+    for (int i = 1; i < op_def->output_arg_size(); ++i) {
+      op->outputs.push_back(absl::StrCat(node.name(), ":", i));
+    }
+  }
+
   // Parse if the op supports quantization
   if (HasAttr(node, kAttrOutputQuantized)) {
     op->quantized = GetBoolAttr(node, kAttrOutputQuantized);
@@ -1071,6 +1108,8 @@
     op->support_output_type_float_in_quantized_op =
         GetBoolAttr(node, kAttrSupportOutputTypeFloatInQuantizedOp);
   }
+
+  // Parse output type(s).
   if (HasAttr(node, kAttrOutputTypes)) {
     const auto& output_types = GetListAttr(node, kAttrOutputTypes);
     for (int i = 0; i < output_types.type_size(); ++i) {
@@ -1079,14 +1118,40 @@
   } else if (HasAttr(node, "Tout")) {
     const auto& output_type = GetDataTypeAttr(node, "Tout");
     op->output_data_types.push_back(ConvertDataType(output_type));
+  } else if (op_def != nullptr) {
+    for (const auto& output_arg : op_def->output_arg()) {
+      if (HasAttr(node, output_arg.type_attr())) {
+        op->output_data_types.push_back(
+            ConvertDataType(GetDataTypeAttr(node, output_arg.type_attr())));
+      } else {
+        LOG(INFO) << "Op node missing output type attribute: " << node.name();
+        op->output_data_types.clear();
+        break;
+      }
+    }
+  } else {
+    // TODO(b/113613439): Figure out how to propagate types for custom ops
+    // that have no OpDef.
+    LOG(INFO) << "Unable to determine output type for op: " << node.op();
   }
+
+  // Parse output shape(s).
   if (HasAttr(node, kAttrOutputShapes)) {
     const auto& output_shapes = GetListAttr(node, kAttrOutputShapes);
     Shape output_shape;
     for (int i = 0; i < output_shapes.shape_size(); ++i) {
+      const auto& shape = output_shapes.shape(i);
+      // TOCO doesn't yet properly handle shapes with wildcard dimensions.
+      // TODO(b/113613439): Handle shape inference for unsupported ops that have
+      // shapes with wildcard dimensions.
+      if (HasWildcardDimension(shape)) {
+        LOG(INFO) << "Skipping wildcard output shape(s) for node: "
+                  << node.name();
+        op->output_shapes.clear();
+        break;
+      }
       const auto status =
-          ImportShape(output_shapes.shape(i).dim(), /*input_flat_size=*/nullptr,
-                      &output_shape);
+          ImportShape(shape.dim(), /*input_flat_size=*/nullptr, &output_shape);
       if (!status.ok()) {
         return status;
       }
@@ -1139,15 +1204,9 @@
   if (node.attr().count("shape")) {
     const auto& shape = GetShapeAttr(node, "shape");
     auto num_dims = shape.dim_size();
-    bool has_wildcard = false;
-    for (std::size_t i = 0; i < num_dims; i++) {
-      if (shape.dim(i).size() == -1) {
-        has_wildcard = true;
-      }
-    }
     // TODO(b/62716978): This logic needs to be revisited.  During dims
     // refactoring it is an interim fix.
-    if (num_dims > 0 && !has_wildcard) {
+    if (num_dims > 0 && !HasWildcardDimension(shape)) {
       auto& dst_array_dims = *array.mutable_shape()->mutable_dims();
       dst_array_dims.resize(num_dims);
       for (std::size_t i = 0; i < num_dims; i++) {
@@ -2023,6 +2082,7 @@
       {"TopKV2", ConvertTopKV2Operator},
       {"Transpose", ConvertSimpleOperator<TransposeOperator, 2>},
       {"Unpack", ConvertUnpackOperator},
+      {"ZerosLike", ConvertSimpleOperator<TensorFlowZerosLikeOperator, 1>},
   });
 }
 
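The unsupported-op path above now registers one output array per OpDef output argument, using TensorFlow's "<node_name>:<index>" tensor naming with ":0" left implicit. A short TF 1.x illustration with a two-output op:

import tensorflow as tf

with tf.Graph().as_default():
  x = tf.placeholder(tf.float32, shape=[8], name="scores")
  values, indices = tf.nn.top_k(x, k=3, name="topk")
  # Only the second and later outputs carry an explicit ":<index>" suffix.
  print(values.name, indices.name)  # topk:0 topk:1
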
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
index a00e136..8a236d4 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
@@ -49,6 +49,39 @@
 
 namespace {
 
+Status ImportNode(const NodeDef& node, Model* model) {
+  const auto converter = internal::GetTensorFlowNodeConverterMap();
+  return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), model,
+                                        converter);
+}
+
+Status ImportNode(const NodeDef& node) {
+  Model model;
+  return ImportNode(node, &model);
+}
+
+NodeDef BuildNode(
+    const std::string& op,
+    const std::vector<std::initializer_list<int>>& output_shapes) {
+  NodeDef node;
+  node.set_op(op);
+  node.set_name("Node1");
+  node.add_input();
+  node.set_input(0, "Node0");
+
+  AttrValue::ListValue* shapes =
+      (*node.mutable_attr())["_output_shapes"].mutable_list();
+  for (const auto& output_shape : output_shapes) {
+    tensorflow::TensorShapeProto* shape = shapes->add_shape();
+    for (int64_t output_shape_dim : output_shape) {
+      auto shape_dim = shape->add_dim();
+      shape_dim->set_size(output_shape_dim);
+    }
+  }
+
+  return node;
+}
+
 class ShapeImportTest : public ::testing::TestWithParam<tensorflow::DataType> {
  protected:
   ShapeImportTest() {}
@@ -109,12 +142,24 @@
     SetAttrValue(t, &value_attr);
     (*node->mutable_attr())["value"] = value_attr;
   }
+};
 
-  Status ImportNode(const NodeDef& node) {
-    Model model;
-    const auto converter = internal::GetTensorFlowNodeConverterMap();
-    return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), &model,
-                                          converter);
+class TypeImportTest : public ::testing::TestWithParam<
+                           std::pair<tensorflow::DataType, ArrayDataType>> {
+ protected:
+  TypeImportTest() {}
+
+  void BuildUnaryNode(const std::string& op_name, tensorflow::DataType dtype,
+                      NodeDef* node) {
+    node->set_op(op_name);
+    node->set_name("Node1");
+
+    node->add_input();
+    node->set_input(0, "Node0");
+
+    AttrValue dtype_attr;
+    SetAttrValue(dtype, &dtype_attr);
+    (*node->mutable_attr())["T"] = dtype_attr;
   }
 };
 
@@ -167,5 +212,77 @@
 INSTANTIATE_TEST_CASE_P(ValidShapeButZeroElements, ShapeImportTest,
                         ::testing::ValuesIn(TestTypes()));
 
+std::vector<std::pair<tensorflow::DataType, ArrayDataType>> UnaryTestTypes() {
+  return {{DT_FLOAT, ArrayDataType::kFloat},
+          {DT_INT32, ArrayDataType::kInt32},
+          {DT_INT64, ArrayDataType::kInt64}};
+}
+
+TEST_P(TypeImportTest, BasicTypeInference) {
+  NodeDef node;
+  BuildUnaryNode("Atan", GetParam().first, &node);
+
+  Model model;
+  EXPECT_TRUE(ImportNode(node, &model).ok());
+
+  ASSERT_THAT(model.operators.size(), ::testing::Ge(1));
+  ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported);
+  const TensorFlowUnsupportedOperator* op =
+      static_cast<const TensorFlowUnsupportedOperator*>(
+          model.operators[0].get());
+  ASSERT_THAT(op->output_data_types, ::testing::ElementsAre(GetParam().second));
+}
+INSTANTIATE_TEST_CASE_P(BasicTypeInference, TypeImportTest,
+                        ::testing::ValuesIn(UnaryTestTypes()));
+
+TEST(ImportTest, FailedTypeInference) {
+  // Create a unary op with no Type ("T") annotation.
+  NodeDef node;
+  node.set_op("Atan");
+  node.set_name("Node1");
+  node.add_input();
+  node.set_input(0, "Node0");
+
+  Model model;
+  EXPECT_TRUE(ImportNode(node, &model).ok());
+
+  ASSERT_THAT(model.operators.size(), ::testing::Ge(1));
+  ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported);
+  const TensorFlowUnsupportedOperator* op =
+      static_cast<const TensorFlowUnsupportedOperator*>(
+          model.operators[0].get());
+  ASSERT_TRUE(op->output_data_types.empty());
+}
+
+TEST(ImportTest, UnsupportedOpWithOutputShapes) {
+  // Create an unsupported op with output shapes.
+  Model model;
+  EXPECT_TRUE(ImportNode(BuildNode("Atan", {{1, 2}, {2, 3}}), &model).ok());
+  ASSERT_THAT(model.operators.size(), ::testing::Ge(1));
+  ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported);
+  const TensorFlowUnsupportedOperator* op =
+      static_cast<const TensorFlowUnsupportedOperator*>(
+          model.operators[0].get());
+
+  // The output shapes should be imported.
+  ASSERT_EQ(op->output_shapes.size(), 2);
+  ASSERT_THAT(op->output_shapes[0].dims(), ::testing::ElementsAre(1, 2));
+  ASSERT_THAT(op->output_shapes[1].dims(), ::testing::ElementsAre(2, 3));
+}
+
+TEST(ImportTest, UnsupportedOpWithWildcardOutputShapes) {
+  // Create an unsupported op with wildcard output shapes.
+  Model model;
+  EXPECT_TRUE(ImportNode(BuildNode("Atan", {{-1, 2}}), &model).ok());
+  ASSERT_THAT(model.operators.size(), ::testing::Ge(1));
+  ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported);
+  const TensorFlowUnsupportedOperator* op =
+      static_cast<const TensorFlowUnsupportedOperator*>(
+          model.operators[0].get());
+
+  // Wildcard shapes aren't yet supported.
+  ASSERT_TRUE(op->output_shapes.empty());
+}
+
 }  // namespace
 }  // namespace toco
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 164b70f..6e207fd 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -150,6 +150,7 @@
   kLogicalOr,
   kCTCBeamSearchDecoder,
   kUnpack,
+  kZerosLike,
 };
 
 // Helper to deal with TensorFlow arrays using a different ordering of
@@ -1849,6 +1850,16 @@
   ArrayDataType dtype = ArrayDataType::kNone;
 };
 
+// ZerosLike operator:
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: tf.zeros_like
+struct TensorFlowZerosLikeOperator : Operator {
+  TensorFlowZerosLikeOperator() : Operator(OperatorType::kZerosLike) {}
+};
+
 // Alloc's are used for transient arrays only. An Alloc specifies which interval
 // of the "transient_data" workspace buffer passed to inference functions, is to
 // be used for the transient array at hand. The 'start' and 'end' values are
@@ -2073,6 +2084,7 @@
     }
   }
   const ArrayMap& GetArrayMap() const { return arrays; }
+  ArrayMap& GetMutableArrayMap() { return arrays; }
 
   int64 ArithmeticOpsCount() const { return ops_count; }
 
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 1061e7c..c59a28b 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -1500,6 +1500,8 @@
       "RSQRT", OperatorType::kRsqrt));
   ops.push_back(MakeUnique<SimpleOperator<TensorFlowSquareOperator>>(
       "SQUARE", OperatorType::kSquare));
+  ops.push_back(MakeUnique<SimpleOperator<TensorFlowZerosLikeOperator>>(
+      "ZEROS_LIKE", OperatorType::kZerosLike));
 
   return ops;
 }
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
index 72e50a9..0bc591e 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -146,6 +146,8 @@
   CheckSimpleOperator<FloorDivOperator>("FLOOR_DIV", OperatorType::kFloorDiv);
   CheckSimpleOperator<TensorFlowSquareOperator>("SQUARE",
                                                 OperatorType::kSquare);
+  CheckSimpleOperator<TensorFlowZerosLikeOperator>("ZEROS_LIKE",
+                                                   OperatorType::kZerosLike);
 }
 
 TEST_F(OperatorTest, BuiltinAdd) {
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index a7c1715..a08b024 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -101,7 +101,6 @@
   transformations->Add(new ResolveTensorFlowSwitch);
   transformations->Add(new ResolveTensorFlowConcat);
   transformations->Add(new ResolveMultiplyByZero);
-  transformations->Add(new IdentifyDilatedConv);
   transformations->Add(new IdentifyL2Normalization);
   transformations->Add(new IdentifyL2Pool);
   transformations->Add(new IdentifyRelu1);
@@ -282,6 +281,14 @@
     }
   }
   transformations.Add(new ResolveConstantConcatenation);
+  // TODO(b/116063589): TF GraphDef doesn't support dilations on its depthwise
+  // conv, so we need to make sure we don't convert to dilated depthwise conv
+  // when outputting to TF GraphDef.
+  auto* identify_dilated_conv = new IdentifyDilatedConv;
+  if (output_format == TENSORFLOW_GRAPHDEF) {
+    identify_dilated_conv->set_identify_depthwise_conv(false);
+  }
+  transformations.Add(identify_dilated_conv);
   RunGraphTransformations(model, "general graph transformations",
                           transformations);
 
@@ -367,9 +374,7 @@
   }
 
   // Deduplicate large constant arrays.
-  if (toco_flags.has_dedupe_array_min_size_bytes()) {
-    DedupeConstantArrays(model, toco_flags.dedupe_array_min_size_bytes());
-  }
+  DedupeConstantArrays(model, toco_flags.dedupe_array_min_size_bytes());
 
   LogDump(kLogLevelModelChanged, "AFTER TRANSFORMATIONS", *model);
 
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 6ab93d9..4a1ae35 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -406,6 +406,7 @@
     HANDLE_OPERATORTYPENAME_CASE(LogicalOr)
     HANDLE_OPERATORTYPENAME_CASE(CTCBeamSearchDecoder)
     HANDLE_OPERATORTYPENAME_CASE(Unpack)
+    HANDLE_OPERATORTYPENAME_CASE(ZerosLike)
     default:
       LOG(FATAL) << "Unhandled op type";
 #undef HANDLE_OPERATORTYPENAME_CASE
diff --git a/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb b/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb
index a96e2c4..80cdb2f 100644
--- a/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb
+++ b/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb
@@ -36,7 +36,7 @@
       "source": [
         "## Overview\n",
         "\n",
-        "[TensorFlow Lite](https://www.tensorflow.org/mobile/tflite/) now supports\n",
+        "[TensorFlow Lite](https://www.tensorflow.org/lite/) now supports\n",
         "converting weights to 8 bit precision as part of model conversion from\n",
         "tensorflow graphdefs to TFLite's flat buffer format. Weight quantization\n",
         "achieves a 4x reduction in the model size. In addition, TFLite supports on the\n",
@@ -542,7 +542,7 @@
       },
       "outputs": [],
       "source": [
-        "print(eval_model(interpreter_quant, mnist_ds))"
+        "print(eval_model(interpreter, mnist_ds))"
       ]
     },
     {
@@ -608,7 +608,8 @@
       "outputs": [],
       "source": [
         "archive_path = tf.keras.utils.get_file(\"resnet_v2_101.tgz\", \"https://storage.googleapis.com/download.tensorflow.org/models/tflite_11_05_08/resnet_v2_101.tgz\", extract=True)\n",
-        "archive_path = pathlib.Path(archive_path)"
+        "archive_path = pathlib.Path(archive_path)\n",
+        "archive_dir = str(archive_path.parent)"
       ]
     },
     {
@@ -631,7 +632,7 @@
       },
       "outputs": [],
       "source": [
-        "! cat {archive_path}/resnet_v2_101_299_info.txt"
+        "! cat {archive_dir}/resnet_v2_101_299_info.txt"
       ]
     },
     {
@@ -664,8 +665,8 @@
       },
       "outputs": [],
       "source": [
-        "archive_dir = str(archive_path.parent)\n",
-        "!ls -lh {archive_dir}"
+        "\n",
+        "!ls -lh {archive_dir}/*.tflite"
       ]
     },
     {
diff --git a/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops_test.py b/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops_test.py
index 4ec539a..9c38914 100644
--- a/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops_test.py
+++ b/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops_test.py
@@ -61,7 +61,7 @@
 class ContrastiveLossTest(test.TestCase):
 
   def testContrastive(self):
-    with self.test_session():
+    with self.cached_session():
       num_data = 10
       feat_dim = 6
       margin = 1.0
@@ -90,7 +90,7 @@
 class TripletSemiHardLossTest(test.TestCase):
 
   def testTripletSemiHard(self):
-    with self.test_session():
+    with self.cached_session():
       num_data = 10
       feat_dim = 6
       margin = 1.0
@@ -146,7 +146,7 @@
 class LiftedStructLossTest(test.TestCase):
 
   def testLiftedStruct(self):
-    with self.test_session():
+    with self.cached_session():
       num_data = 10
       feat_dim = 6
       margin = 1.0
@@ -217,7 +217,7 @@
 class NpairsLossTest(test.TestCase):
 
   def testNpairs(self):
-    with self.test_session():
+    with self.cached_session():
       num_data = 15
       feat_dim = 6
       num_classes = 5
@@ -261,7 +261,7 @@
 class NpairsLossMultiLabelTest(test.TestCase):
 
   def testNpairsMultiLabelLossWithSingleLabelEqualsNpairsLoss(self):
-    with self.test_session():
+    with self.cached_session():
       num_data = 15
       feat_dim = 6
       reg_lambda = 0.02
@@ -290,7 +290,7 @@
       self.assertAllClose(loss_npairs, loss_npairs_multilabel)
 
   def testNpairsMultiLabel(self):
-    with self.test_session():
+    with self.cached_session():
       num_data = 15
       feat_dim = 6
       num_classes = 10
@@ -527,7 +527,7 @@
   def testClusteringLossPAMOff(self):
     if not HAS_SKLEARN:
       return
-    with self.test_session():
+    with self.cached_session():
       margin_multiplier = 10.0
       embeddings, labels = self._genClusters(n_samples=128, n_clusters=64)
 
@@ -544,7 +544,7 @@
   def testClusteringLossPAMOn(self):
     if not HAS_SKLEARN:
       return
-    with self.test_session():
+    with self.cached_session():
       margin_multiplier = 10.0
       embeddings, labels = self._genClusters(n_samples=128, n_clusters=64)
 
diff --git a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
index 1d6d9a6..0d8df93 100644
--- a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
+++ b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
@@ -10,7 +10,6 @@
 tensorflow/core/framework/graph_transfer_info.pb.cc
 tensorflow/core/framework/kernel_def.pb.cc
 tensorflow/core/framework/log_memory.pb.cc
-tensorflow/core/framework/model.pb.cc
 tensorflow/core/framework/node_def.pb.cc
 tensorflow/core/framework/op_def.pb.cc
 tensorflow/core/framework/remote_fused_graph_execute_info.pb.cc
diff --git a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt
index 884461e..d982df9 100644
--- a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt
+++ b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt
@@ -10,7 +10,6 @@
 tensorflow/core/framework/graph_transfer_info.pb.h
 tensorflow/core/framework/kernel_def.pb.h
 tensorflow/core/framework/log_memory.pb.h
-tensorflow/core/framework/model.pb.h
 tensorflow/core/framework/node_def.pb.h
 tensorflow/core/framework/op_def.pb.h
 tensorflow/core/framework/remote_fused_graph_execute_info.pb.h
diff --git a/tensorflow/contrib/makefile/tf_pb_text_files.txt b/tensorflow/contrib/makefile/tf_pb_text_files.txt
index e23f499..f94d70d 100644
--- a/tensorflow/contrib/makefile/tf_pb_text_files.txt
+++ b/tensorflow/contrib/makefile/tf_pb_text_files.txt
@@ -10,7 +10,6 @@
 tensorflow/core/framework/graph_transfer_info.pb_text.cc
 tensorflow/core/framework/kernel_def.pb_text.cc
 tensorflow/core/framework/log_memory.pb_text.cc
-tensorflow/core/framework/model.pb_text.cc
 tensorflow/core/framework/node_def.pb_text.cc
 tensorflow/core/framework/op_def.pb_text.cc
 tensorflow/core/framework/remote_fused_graph_execute_info.pb_text.cc
diff --git a/tensorflow/contrib/makefile/tf_proto_files.txt b/tensorflow/contrib/makefile/tf_proto_files.txt
index 5eae845..8bec3e3 100644
--- a/tensorflow/contrib/makefile/tf_proto_files.txt
+++ b/tensorflow/contrib/makefile/tf_proto_files.txt
@@ -14,7 +14,6 @@
 tensorflow/core/framework/graph_transfer_info.proto
 tensorflow/core/framework/kernel_def.proto
 tensorflow/core/framework/log_memory.proto
-tensorflow/core/framework/model.proto
 tensorflow/core/framework/node_def.proto
 tensorflow/core/framework/op_def.proto
 tensorflow/core/framework/reader_base.proto
diff --git a/tensorflow/contrib/metrics/python/kernel_tests/histogram_ops_test.py b/tensorflow/contrib/metrics/python/kernel_tests/histogram_ops_test.py
index 1d18d6b..bed1ecb 100644
--- a/tensorflow/contrib/metrics/python/kernel_tests/histogram_ops_test.py
+++ b/tensorflow/contrib/metrics/python/kernel_tests/histogram_ops_test.py
@@ -31,21 +31,21 @@
   """Test this private function."""
 
   def test_empty_tensor_returns_empty(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = constant_op.constant([])
       result = histogram_ops._strict_1d_cumsum(tensor, 0)
       expected = constant_op.constant([])
       np.testing.assert_array_equal(expected.eval(), result.eval())
 
   def test_length_1_tensor_works(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = constant_op.constant([3], dtype=dtypes.float32)
       result = histogram_ops._strict_1d_cumsum(tensor, 1)
       expected = constant_op.constant([3], dtype=dtypes.float32)
       np.testing.assert_array_equal(expected.eval(), result.eval())
 
   def test_length_3_tensor_works(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = constant_op.constant([1, 2, 3], dtype=dtypes.float32)
       result = histogram_ops._strict_1d_cumsum(tensor, 3)
       expected = constant_op.constant([1, 3, 6], dtype=dtypes.float32)
@@ -58,7 +58,7 @@
     self.rng = np.random.RandomState(0)
 
   def test_empty_labels_and_scores_gives_nan_auc(self):
-    with self.test_session():
+    with self.cached_session():
       labels = constant_op.constant([], shape=[0], dtype=dtypes.bool)
       scores = constant_op.constant([], shape=[0], dtype=dtypes.float32)
       score_range = [0, 1.]
@@ -155,7 +155,7 @@
         from synthetic data.
     """
     score_range = [0, 1.] or score_range
-    with self.test_session():
+    with self.cached_session():
       labels = array_ops.placeholder(dtypes.bool, shape=[num_records])
       scores = array_ops.placeholder(dtypes.float32, shape=[num_records])
       auc, update_op = histogram_ops.auc_using_histogram(
diff --git a/tensorflow/contrib/metrics/python/metrics/classification_test.py b/tensorflow/contrib/metrics/python/metrics/classification_test.py
index 3d0b81c..d6a670f 100644
--- a/tensorflow/contrib/metrics/python/metrics/classification_test.py
+++ b/tensorflow/contrib/metrics/python/metrics/classification_test.py
@@ -34,7 +34,7 @@
 class ClassificationTest(test.TestCase):
 
   def testAccuracy1D(self):
-    with self.test_session() as session:
+    with self.cached_session() as session:
       pred = array_ops.placeholder(dtypes.int32, shape=[None])
       labels = array_ops.placeholder(dtypes.int32, shape=[None])
       acc = classification.accuracy(pred, labels)
@@ -44,7 +44,7 @@
       self.assertEqual(result, 0.5)
 
   def testAccuracy1DBool(self):
-    with self.test_session() as session:
+    with self.cached_session() as session:
       pred = array_ops.placeholder(dtypes.bool, shape=[None])
       labels = array_ops.placeholder(dtypes.bool, shape=[None])
       acc = classification.accuracy(pred, labels)
@@ -54,7 +54,7 @@
       self.assertEqual(result, 0.5)
 
   def testAccuracy1DInt64(self):
-    with self.test_session() as session:
+    with self.cached_session() as session:
       pred = array_ops.placeholder(dtypes.int64, shape=[None])
       labels = array_ops.placeholder(dtypes.int64, shape=[None])
       acc = classification.accuracy(pred, labels)
@@ -64,7 +64,7 @@
       self.assertEqual(result, 0.5)
 
   def testAccuracy1DString(self):
-    with self.test_session() as session:
+    with self.cached_session() as session:
       pred = array_ops.placeholder(dtypes.string, shape=[None])
       labels = array_ops.placeholder(dtypes.string, shape=[None])
       acc = classification.accuracy(pred, labels)
@@ -87,7 +87,7 @@
       classification.accuracy(pred, labels)
 
   def testAccuracy1DWeighted(self):
-    with self.test_session() as session:
+    with self.cached_session() as session:
       pred = array_ops.placeholder(dtypes.int32, shape=[None])
       labels = array_ops.placeholder(dtypes.int32, shape=[None])
       weights = array_ops.placeholder(dtypes.float32, shape=[None])
@@ -101,7 +101,7 @@
       self.assertEqual(result, 0.5)
 
   def testAccuracy1DWeightedBroadcast(self):
-    with self.test_session() as session:
+    with self.cached_session() as session:
       pred = array_ops.placeholder(dtypes.int32, shape=[None])
       labels = array_ops.placeholder(dtypes.int32, shape=[None])
       weights = array_ops.placeholder(dtypes.float32, shape=[])
@@ -161,7 +161,7 @@
         (10, 3), maxval=2, dtype=dtypes.int64, seed=2)
     f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=3)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates.
@@ -176,7 +176,7 @@
   def testAllCorrect(self):
     inputs = np.random.randint(0, 2, size=(100, 1))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(inputs, dtype=dtypes.float32)
       labels = constant_op.constant(inputs)
       f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=3)
@@ -191,7 +191,7 @@
         [1, 0, 1, 0], shape=(1, 4), dtype=dtypes.float32)
     labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
     f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=1)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       sess.run([f1_op])
       # Threshold 0 will have around 0.5 precision and 1 recall yielding an F1
@@ -201,7 +201,7 @@
   def testAllIncorrect(self):
     inputs = np.random.randint(0, 2, size=(10000, 1))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(inputs, dtype=dtypes.float32)
       labels = constant_op.constant(1 - inputs, dtype=dtypes.float32)
       f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=3)
@@ -214,7 +214,7 @@
       self.assertAlmostEqual(2 * 0.5 * 1 / (1 + 0.5), f1.eval(), places=2)
 
   def testWeights1d(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes.float32)
       labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -228,7 +228,7 @@
       self.assertAlmostEqual(1.0, f1.eval(), places=5)
 
   def testWeights2d(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes.float32)
       labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -242,7 +242,7 @@
       self.assertAlmostEqual(1.0, f1.eval(), places=5)
 
   def testZeroLabelsPredictions(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = array_ops.zeros([4], dtype=dtypes.float32)
       labels = array_ops.zeros([4])
       f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=3)
@@ -300,7 +300,7 @@
     f1, f1_op = classification.f1_score(tf_labels, tf_predictions,
                                         num_thresholds=3)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       for _ in range(num_batches):
         sess.run([f1_op])
diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py
index f08ffaa..089ecf5 100644
--- a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py
@@ -236,7 +236,7 @@
                              opt.get_slot(var=var0, name="m").name)
 
   def testBasic(self):
-    with self.test_session():
+    with self.cached_session():
       self.doTestBasic(use_resource=False)
 
   @test_util.run_in_graph_and_eager_modes(reset_test=True)
@@ -249,7 +249,7 @@
 
   def testTensorLearningRate(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         # Initialize variables for numpy implementation.
         m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
         var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -286,7 +286,7 @@
 
   def testSharing(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         # Initialize variables for numpy implementation.
         m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
         var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
diff --git a/tensorflow/contrib/optimizer_v2/adagrad.py b/tensorflow/contrib/optimizer_v2/adagrad.py
index 25ec475..dab1e02 100644
--- a/tensorflow/contrib/optimizer_v2/adagrad.py
+++ b/tensorflow/contrib/optimizer_v2/adagrad.py
@@ -31,7 +31,7 @@
 
   See this [paper](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
   or this
-  [intro](http://cs.stanford.edu/~ppasupat/a9online/uploads/proximal_notes.pdf).
+  [intro](https://ppasupat.github.io/a9online/uploads/proximal_notes.pdf).
   """
 
   def __init__(self, learning_rate, initial_accumulator_value=0.1,
diff --git a/tensorflow/contrib/quantization/README.md b/tensorflow/contrib/quantization/README.md
index 359950a..826e8db 100644
--- a/tensorflow/contrib/quantization/README.md
+++ b/tensorflow/contrib/quantization/README.md
@@ -2,6 +2,6 @@
 
 If you are looking for quantized training rewrites that allow for training
 quantized models that work with
-[TensorFlow Lite](https://www.tensorflow.org/mobile/tflite/), you should look at
+[TensorFlow Lite](https://www.tensorflow.org/lite/), you should look at
 the [contrib/quantize](https://www.tensorflow.org/api_docs/python/tf/contrib/quantize)
 package.
diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD
index c59f667..23e3a25 100644
--- a/tensorflow/contrib/quantize/BUILD
+++ b/tensorflow/contrib/quantize/BUILD
@@ -20,9 +20,13 @@
     srcs_version = "PY2AND3",
     deps = [
         ":common",
+        "//tensorflow/contrib/layers:layers_py",
+        "//tensorflow/python:array_ops",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:init_ops",
         "//tensorflow/python:math_ops",
+        "//tensorflow/python:nn_ops",
         "//tensorflow/python:platform_test",
         "//tensorflow/python:session",
         "//tensorflow/python:variable_scope",
diff --git a/tensorflow/contrib/quantize/README.md b/tensorflow/contrib/quantize/README.md
index 27a933c..0ab19c9 100644
--- a/tensorflow/contrib/quantize/README.md
+++ b/tensorflow/contrib/quantize/README.md
@@ -1,65 +1,155 @@
-# Quantized Training Rewrites
+# Quantization-aware training
 
-tf.contrib.quantize provides tools for transforming graphs to include ops to
-model quantization of weights, biases and activations during both training and
-inference. The details of the transformation implemented in this package is
-described here [1].
+Quantization-aware model training ensures that the forward pass matches precision
+for both training and inference. There are two aspects to this:
 
-This is done using the
-[fake quantization op](https://www.tensorflow.org/api_guides/python/array_ops#Fake_quantization).
+* Operator fusion at inference time is accurately modeled at training time.
+* Quantization effects at inference are modeled at training time.
 
-Literature has shown that fixed point networks provide comparable performance to
-floating point networks [2]. This is achieved by modeling the quantization
-operation during training in both the forward and backward passes.
-The fake quantization operator achieves this by modeling the quantizer as a pass
-through estimator [3]. Note that during back propagation, the parameters are
+For efficient inference, TensorFlow combines batch normalization with the preceding
+convolutional and fully-connected layers prior to quantization by
+[folding batch norm layers](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/quantize/python/fold_batch_norms.py){:.external}. 
+
+The quantization error is modeled using [fake quantization](../api_guides/python/array_ops.md#Fake_quantization)
+nodes to simulate the effect of quantization in the forward and backward passes. The
+forward pass models quantization, while the backward pass models the quantizer as a
+straight-through estimator. Both passes simulate the quantization
+of weights and activations. Note that during back propagation, the parameters are
 updated at high precision as this is needed to ensure sufficient precision in
-accumulating tiny adjustments to the parameters. However, for the forward pass,
-the parameters and activations are quantized to the desired lower precision.
+accumulating tiny adjustments to the parameters.
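+
+As a minimal illustration (a sketch, separate from the rewrite API described
+below), the standalone `tf.fake_quant_with_min_max_args` op simulates 8-bit
+quantization of a tensor in the forward pass while passing gradients straight
+through in the backward pass:
+
+```
+import tensorflow as tf
+
+x = tf.constant([-1.2, 0.0, 0.7, 3.4])
+# Values are clamped to [min, max] and rounded to one of 2^num_bits levels;
+# gradients pass through unchanged for values inside the range.
+x_quantized = tf.fake_quant_with_min_max_args(x, min=-1.0, max=1.0, num_bits=8)
+```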
 
-## How to use the Rewrites
 
-tf.contrib.quantize provides two rewrites, one to train for quantization and
-one to create a [TensorFlow Lite](https://www.tensorflow.org/mobile/tflite/)
-compatible eval graph.
+Additionally, the minimum and maximum values for activations are determined
+during training. This allows a model trained with quantization in the loop to be
+converted to a fixed point inference model with little effort, eliminating the
+need for a separate calibration step.
+
+Since it's difficult to add these fake quantization operations to all the
+required locations in the model, there's a function available that rewrites the
+training graph. To create a fake quantized training graph:
 
 ```
 # Build forward pass of model.
-…
 loss = tf.losses.get_total_loss()
 
-# Call the training rewrite which rewrites the graph in-place with FakeQuantization nodes
-# and folds batchnorm for training.
-# It is often needed to finetune a floating point model for quantization with this training tool.
-# When training from scratch, quant_delay can be used to activate quantization after
-# training to convergence with the float graph, effectively finetuning the model.
-tf.contrib.quantize.create_training_graph(quant_delay=2000000)
+# Call the training rewrite which rewrites the graph in-place with
+# FakeQuantization nodes and folds batchnorm for training. It is
+# often needed to fine tune a floating point model for quantization
+# with this training tool. When training from scratch, quant_delay
+# can be used to activate quantization after training to convergence
+# with the float graph, effectively fine-tuning the model.
+g = tf.get_default_graph()
+tf.contrib.quantize.create_training_graph(input_graph=g,
+                                          quant_delay=2000000)
 
 # Call backward pass optimizer as usual.
 optimizer = tf.train.GradientDescentOptimizer(learning_rate)
 optimizer.minimize(loss)
 ```
 
-Additionally, the rewritten eval graph is non-trivially different from the
-training graph due the effects of quantization on batch normalization. Thus,
-we offer a separate rewrite for the eval_graph.
+The rewritten *eval graph* is non-trivially different from the *training graph*
+since the quantization ops affect the batch normalization step. Because of this,
+we've added a separate rewrite for the *eval graph*:
 
 ```
 # Build eval model
-…
-logits = tf.nn.softmax_cross_entropy_with_logits(...)
+logits = tf.nn.softmax_cross_entropy_with_logits_v2(...)
 
-# Call the eval rewrite which rewrites the graph in-place with FakeQuantization nodes
-# and fold batchnorm for eval.
-tf.contrib.quantize.create_eval_graph()
+# Call the eval rewrite which rewrites the graph in-place with
+# FakeQuantization nodes and fold batchnorm for eval.
+g = tf.get_default_graph()
+tf.contrib.quantize.create_eval_graph(input_graph=g)
 
-# Save the checkpoint and eval graph proto to disk for freezing and providing to TFLite.
+# Save the checkpoint and eval graph proto to disk for freezing
+# and providing to TFLite.
 with open(eval_graph_file, ‘w’) as f:
   f.write(str(g.as_graph_def()))
 saver = tf.train.Saver()
 saver.save(sess, checkpoint_name)
 ```
 
+Methods to rewrite the training and eval graphs are an active area of research
+and experimentation. Although rewrites and quantized training might not work or
+improve performance for all models, we are working to generalize these techniques.
+
+
+## Generating fully-quantized models
+
+The eval graph produced by the rewrite above only *simulates*
+quantization. To run real fixed-point computations, a trained quantization
+model must be converted to use fixed-point kernels. TensorFlow Lite supports
+this conversion from the graph resulting from `create_eval_graph`.
+
+First, create a frozen graph that will be the input for the TensorFlow Lite
+toolchain:
+
+```
+freeze_graph \
+  --input_graph=eval_graph_def.pb \
+  --input_checkpoint=checkpoint \
+  --output_graph=frozen_eval_graph.pb --output_node_names=outputs
+```
+
+Provide this to the TensorFlow Lite Optimizing Converter (TOCO) to get a
+fully-quantized TensorFlow Lite model:
+
+```
+toco \
+  --input_file=frozen_eval_graph.pb \
+  --output_file=tflite_model.tflite \
+  --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE \
+  --inference_type=QUANTIZED_UINT8 \
+  --input_shape="1,224, 224,3" \
+  --input_array=input \
+  --output_array=outputs \
+  --std_value=127.5 --mean_value=127.5
+```
+
+See the documentation for `tf.contrib.quantize` and [TensorFlow Lite](../lite/).
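+
+As a quick sanity check (a sketch using the file name and input shape from the
+TOCO example above), the fully-quantized model can be run from Python with the
+TensorFlow Lite interpreter:
+
+```
+import numpy as np
+import tensorflow as tf
+
+interpreter = tf.contrib.lite.Interpreter(model_path="tflite_model.tflite")
+interpreter.allocate_tensors()
+input_details = interpreter.get_input_details()
+output_details = interpreter.get_output_details()
+
+# A QUANTIZED_UINT8 model expects uint8 inputs of shape [1, 224, 224, 3].
+image = np.zeros((1, 224, 224, 3), dtype=np.uint8)
+interpreter.set_tensor(input_details[0]['index'], image)
+interpreter.invoke()
+predictions = interpreter.get_tensor(output_details[0]['index'])
+```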
+
+
+## Quantized accuracy results
+
+The following are results of training some popular CNN models (Mobilenet-v1,
+Mobilenet-v2, and Inception-v3) using this tool:
+
+<figure>
+  <table>
+    <tr>
+      <th>Model</th>
+      <th>Top-1 Accuracy:<br>Floating point</th>
+      <th>Top-1 Accuracy:<br>Fixed point: 8 bit weights and activations</th>
+    </tr>
+    <tr><td>Mobilenet-v1-128-0.25</td><td>0.415</td><td>0.399</td></tr>
+    <tr><td>Mobilenet-v1-128-0.5</td><td>0.563</td><td>0.549</td></tr>
+    <tr><td>Mobilenet-v1-128-0.75</td><td>0.621</td><td>0.598</td></tr>
+    <tr><td>Mobilenet-v1-128-1</td><td>0.652</td><td>0.64</td></tr>
+    <tr><td>Mobilenet-v1-160-0.25</td><td>0.455</td><td>0.435</td></tr>
+    <tr><td>Mobilenet-v1-160-0.5</td><td>0.591</td><td>0.577</td></tr>
+    <tr><td>Mobilenet-v1-160-0.75</td><td>0.653</td><td>0.639</td></tr>
+    <tr><td>Mobilenet-v1-160-1</td><td>0.68</td><td>0.673</td></tr>
+    <tr><td>Mobilenet-v1-192-0.25</td><td>0.477</td><td>0.458</td></tr>
+    <tr><td>Mobilenet-v1-192-0.5</td><td>0.617</td><td>0.604</td></tr>
+    <tr><td>Mobilenet-v1-192-0.75</td><td>0.672</td><td>0.662</td></tr>
+    <tr><td>Mobilenet-v1-192-1</td><td>0.7</td><td>0.69</td></tr>
+    <tr><td>Mobilenet-v1-224-0.25</td><td>0.498</td><td>0.482</td></tr>
+    <tr><td>Mobilenet-v1-224-0.5</td><td>0.633</td><td>0.622</td></tr>
+    <tr><td>Mobilenet-v1-224-0.75</td><td>0.684</td><td>0.679</td></tr>
+    <tr><td>Mobilenet-v1-224-1</td><td>0.709</td><td>0.697</td></tr>
+    <tr><td>Mobilenet-v2-224-1</td><td>0.718</td><td>0.708</td></tr>
+    <tr><td>Inception_v3</td><td>0.78</td><td>0.775</td></tr>
+  </table>
+  <figcaption>
+    <b>Table 1</b>: Top-1 accuracy of floating point and fully quantized CNNs on Imagenet Validation dataset.
+  </figcaption>
+</figure>
+
+Our pre-trained models are available in the
+<a href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/models.md#image-classification-quantized-models" class="external">TensorFlow Lite model repository</a>. The code used to generate
+these models <a href="https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1_train.py" class="external">is available</a>.
+
+
+
 These rewrites are an active area of research and experimentation, so the
 rewrites and quantized training will likely not work across all models, though
 we hope to work towards generalizing these techniques.
diff --git a/tensorflow/contrib/quantize/python/common.py b/tensorflow/contrib/quantize/python/common.py
index b27117d..e6c04bc 100644
--- a/tensorflow/contrib/quantize/python/common.py
+++ b/tensorflow/contrib/quantize/python/common.py
@@ -34,10 +34,10 @@
     'ScalarSummary')
 
 # Valid activation ops for quantization end points.
-_ACTIVATION_OP_SUFFIXES = ['/Relu6', '/Relu', '/Identity']
+_ACTIVATION_OP_SUFFIXES = ['Relu6', 'Relu', 'Identity']
 
 # Regular expression for recognizing nodes that are part of batch norm group.
-_BATCHNORM_RE = re.compile(r'^(.*)/BatchNorm/batchnorm')
+_BATCHNORM_RE = re.compile(r'^(.*)BatchNorm/batchnorm')
 
 
 def BatchNormGroups(graph):
diff --git a/tensorflow/contrib/quantize/python/common_test.py b/tensorflow/contrib/quantize/python/common_test.py
index 2b26302..a3ce041 100644
--- a/tensorflow/contrib/quantize/python/common_test.py
+++ b/tensorflow/contrib/quantize/python/common_test.py
@@ -13,21 +13,26 @@
 # limitations under the License.
 # ==============================================================================
 """Tests for common utilities in this package."""
-
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-
+from tensorflow.contrib.layers.python.layers import layers
 from tensorflow.contrib.quantize.python import common
 from tensorflow.python.client import session
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import googletest
 
+batch_norm = layers.batch_norm
+conv2d = layers.conv2d
+
 
 class CommonTest(test_util.TensorFlowTestCase):
 
@@ -87,6 +92,56 @@
     for i in inputs:
       self.assertIn(i, op.inputs)
 
+  def testBatchNormScope(self):
+    batch_size, height, width, depth = 5, 128, 128, 3
+    g = ops.Graph()
+    with g.as_default():
+      inputs = array_ops.zeros((batch_size, height, width, depth))
+      stride = 1
+      out_depth = 32
+      scope = ''
+      node = conv2d(
+          inputs,
+          out_depth, [2, 2],
+          stride=stride,
+          padding='SAME',
+          weights_initializer=self._WeightInit(0.09),
+          activation_fn=None,
+          normalizer_fn=batch_norm,
+          normalizer_params=self._BatchNormParams(False),
+          scope=scope)
+
+      node = nn_ops.relu(node, name='Relu6')
+    bn_list = common.BatchNormGroups(g)
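+    # Write the graph def to a temporary file so it can be inspected manually.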
+    with open('/tmp/common_test.pbtxt', 'w') as f:
+      f.write(str(g.as_graph_def()))
+
+    # Exactly one batch norm layer with an empty scope should be found.
+    self.assertEqual(len(bn_list), 1)
+    self.assertEqual(bn_list[0], '')
+
+  def _BatchNormParams(self, fused=False, force_updates=False):
+    params = {
+        'center': True,
+        'scale': True,
+        'decay': 1.0 - 0.003,
+        'fused': fused
+    }
+    return params
+
+  def _WeightInit(self, stddev):
+    """Returns a truncated normal variable initializer.
+
+    Function is defined purely to shorten the name so that it stops wrapping.
+
+    Args:
+      stddev: Standard deviation of normal variable.
+
+    Returns:
+      An initializer that initializes with a truncated normal variable.
+    """
+    return init_ops.truncated_normal_initializer(stddev=stddev, seed=1234)
+
 
 if __name__ == '__main__':
   googletest.main()
diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py
index 2971b28..e5790a6 100644
--- a/tensorflow/contrib/quantize/python/fold_batch_norms.py
+++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py
@@ -95,8 +95,7 @@
               _ComputeBatchNormCorrections(
                   context='',
                   match=match,
-                  freeze_batch_norm_delay=freeze_batch_norm_delay,
-                  fused_batch_norm=True))
+                  freeze_batch_norm_delay=freeze_batch_norm_delay))
         # The shape of depthwise weights is different, so we need to reshape the
         # multiplier_tensor to ensure that the scaled_weight_tensor has the
         # expected shape.
@@ -296,8 +295,7 @@
         batch_to_space_op=batch_to_space_op)
 
 
-def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
-                                 fused_batch_norm):
+def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay):
   """Computes batch norm correction params.
 
      Before batch normalization is frozen:
@@ -327,14 +325,14 @@
       computation.
     freeze_batch_norm_delay: Delay in steps at which computation switches
       from regular batch norm to frozen mean and variance.
-    fused_batch_norm: Bool, true if fused batch norm is used.
+
 
   Returns:
     A tuple of correction_scale, correction_recip, correction_offset
   """
 
   g = ops.get_default_graph()
-  prefix = '' if not context else context + '/'
+  prefix = '' if not context else context
   with g.name_scope(prefix + 'batch_norm_correction'):
     recip_sigma_mv = math_ops.rsqrt(
         match.moving_variance_tensor + match.batch_epsilon)
@@ -495,8 +493,23 @@
 
     # Treat consumer ops in bypass modules differently since they have Add
     # operations instead of Relu* above.
-    add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
-    add_bypass = graph.get_operation_by_name(add_bypass_ctx + '/Add')
+    # Select the correct scope for the bypass add. If the batch norm scope has
+    # the form str1/str2, the bypass add is at scope str1. If the batch norm
+    # scope is just str1, the bypass add is at scope ''. If there is no batch
+    # norm, there is no bypass add.
+    add_bypass_ctx = ''
+    if bn:
+      try:
+        add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
+      except AttributeError:
+        add_bypass_ctx = ''
+
+    if add_bypass_ctx:
+      add_bypass_ctx = add_bypass_ctx + '/'
+
+    add_bypass = graph.get_operation_by_name(add_bypass_ctx + 'Add')
     nodes_modified_count = common.RerouteTensor(
         folded_op.outputs[0], original_op.outputs[0], can_modify=[add_bypass])
     if nodes_modified_count != 1:
@@ -505,8 +518,8 @@
 
 def _IsValidUnfusedBatchNorm(graph, context):
   """Checks that the output of the unfused batch norm has consumers."""
-  add_shift = graph.get_operation_by_name(
-      context + '/BatchNorm/batchnorm_1/add_1')
+  add_shift = graph.get_operation_by_name(context +
+                                          'BatchNorm/batchnorm_1/add_1')
   # Ensure that the output tensor of batch norm has consumers, otherwise this
   # is a dangling node and not a match.
   return bool(add_shift.outputs[0].consumers())
@@ -538,7 +551,8 @@
     if op.name.endswith(match_pattern):
       split_name = op.name.split('/')
       num_matches = len(set(split_name) & split_context)
-      if num_matches > 0:
+
+      if num_matches > 0 or not scope:
         match_dict[op.name] = num_matches
   # match_dict contains matching op names from graph with values being
   # number of matches to scope. We pick the key with the most matches
@@ -597,21 +611,21 @@
   # op.name =  MobilenetV2/expanded_conv_3/depthwise/BatchNorm/moving_mean/read
   # will have 2 matches,scope with a different conv layer will have one match.
 
-  op_suffix_mean = '/BatchNorm/moments/Squeeze'
-  op_suffix_variance = '/BatchNorm/moments/Squeeze_1'
-  op_suffix_epsilon = '/BatchNorm/batchnorm_1/add/y'
-  op_suffix_bn_decay_mean = '/BatchNorm/AssignMovingAvg/decay'
-  op_suffix_bn_decay_var = '/BatchNorm/AssignMovingAvg_1/decay'
+  op_suffix_mean = 'BatchNorm/moments/Squeeze'
+  op_suffix_variance = 'BatchNorm/moments/Squeeze_1'
+  op_suffix_epsilon = 'BatchNorm/batchnorm_1/add/y'
+  op_suffix_bn_decay_mean = 'BatchNorm/AssignMovingAvg/decay'
+  op_suffix_bn_decay_var = 'BatchNorm/AssignMovingAvg_1/decay'
 
   if variable_scope.get_variable_scope().use_resource:
-    op_suffix_gamma = '/BatchNorm/gamma/Read/ReadVariableOp'
+    op_suffix_gamma = 'BatchNorm/gamma/Read/ReadVariableOp'
     op_suffix_moving_variance = (
-        '/BatchNorm/moving_variance/Read/ReadVariableOp')
-    op_suffix_moving_mean = ('/BatchNorm/moving_mean/Read/ReadVariableOp')
+        'BatchNorm/moving_variance/Read/ReadVariableOp')
+    op_suffix_moving_mean = ('BatchNorm/moving_mean/Read/ReadVariableOp')
   else:
-    op_suffix_gamma = '/BatchNorm/gamma'
-    op_suffix_moving_variance = '/BatchNorm/moving_variance/read'
-    op_suffix_moving_mean = '/BatchNorm/moving_mean/read'
+    op_suffix_gamma = 'BatchNorm/gamma'
+    op_suffix_moving_variance = 'BatchNorm/moving_variance/read'
+    op_suffix_moving_mean = 'BatchNorm/moving_mean/read'
   # Parse through list of ops to find relevant ops
 
   batch_mean_tensor = _FindMatchingTensor(graph, op_suffix_mean, context)
@@ -679,8 +693,7 @@
       the folded graph (add_fold).
   """
   mul_scale_name = 'mul_1' if has_scaling else 'mul'
-  mul_scale = graph.get_operation_by_name(context +
-                                          '/BatchNorm/batchnorm_1/' +
+  mul_scale = graph.get_operation_by_name(context + 'BatchNorm/batchnorm_1/' +
                                           mul_scale_name)
   op_below = mul_scale.inputs[0].op
   # Skip over the BatchToSpace operation in the case of atrous convolutions.
@@ -697,8 +710,7 @@
         _ComputeBatchNormCorrections(
             context=context,
             match=match,
-            freeze_batch_norm_delay=freeze_batch_norm_delay,
-            fused_batch_norm=False))
+            freeze_batch_norm_delay=freeze_batch_norm_delay))
   # Special handling for weights of depthwise convolution.
   if op_below.type == 'DepthwiseConv2dNative':
     new_shape = [
@@ -706,27 +718,27 @@
         weights.get_shape().as_list()[3]
     ]
     scale_name = 'mul' if has_scaling else 'Rsqrt'
-    scale = graph.get_operation_by_name(
-        context + '/BatchNorm/batchnorm_1/' + scale_name)
+    scale = graph.get_operation_by_name(context + 'BatchNorm/batchnorm_1/' +
+                                        scale_name)
     scale = array_ops.reshape(scale.outputs[0], new_shape,
-                              context + '/scale_reshape')
+                              context + 'scale_reshape')
 
     if correction_scale is not None:
       correction_scale = array_ops.reshape(correction_scale, new_shape,
-                                           context + '/correction_reshape')
+                                           context + 'correction_reshape')
       with ops.device(mul_scale.device):
         weights = math_ops.multiply(correction_scale, weights,
-                                    context + '/correction_mult')
+                                    context + 'correction_mult')
 
-    mul_fold = _CloneOp(mul_scale, context + '/mul_fold', [(0, weights),
-                                                           (1, scale)])
+    mul_fold = _CloneOp(mul_scale, context + 'mul_fold', [(0, weights),
+                                                          (1, scale)])
   elif op_below.type in ['Conv2D', 'MatMul']:
 
     if correction_scale is not None:
       with ops.device(mul_scale.device):
         weights = math_ops.multiply(correction_scale, weights,
-                                    context + '/correction_mult')
-    mul_fold = _CloneOp(mul_scale, context + '/mul_fold', [(0, weights)])
+                                    context + 'correction_mult')
+    mul_fold = _CloneOp(mul_scale, context + 'mul_fold', [(0, weights)])
   else:
     raise ValueError('Cannot handle operation of type: %s' % op_below.type)
   _AssertShapesMatch('mul_fold', mul_fold.inputs[0], mul_fold.outputs[0])
@@ -734,8 +746,8 @@
   conv_or_fc_folded = _CloneOp(op_below, op_below.name + '_Fold',
                                [(1, mul_fold.outputs[0])])
 
-  add_shift = graph.get_operation_by_name(
-      context + '/BatchNorm/batchnorm_1/add_1')
+  add_shift = graph.get_operation_by_name(context +
+                                          'BatchNorm/batchnorm_1/add_1')
 
   corrected_output = conv_or_fc_folded.outputs[0]
   # Copy the batch to space operation if we have a atrous convolution.
@@ -748,10 +760,10 @@
   if correction_offset is not None:
     with ops.device(conv_or_fc_folded.device):
       corrected_output = math_ops.multiply(correction_recip, corrected_output,
-                                           context + '/post_conv_mul')
+                                           context + 'post_conv_mul')
       corrected_output = math_ops.add(corrected_output, (correction_offset),
-                                      context + '/correction_add')
-  add_fold = _CloneOp(add_shift, context + '/add_fold', [(0, corrected_output)])
+                                      context + 'correction_add')
+  add_fold = _CloneOp(add_shift, context + 'add_fold', [(0, corrected_output)])
   _AssertShapesMatch('add_fold', add_fold.inputs[0], add_fold.outputs[0])
   return add_shift, add_fold
 
@@ -930,7 +942,7 @@
   Returns:
     A boolean indicating whether this batch norm layer has scaling enabled.
   """
-  rsqrt_op = graph.get_operation_by_name(bn + '/BatchNorm/batchnorm_1/Rsqrt')
+  rsqrt_op = graph.get_operation_by_name(bn + 'BatchNorm/batchnorm_1/Rsqrt')
   rsqrt_consumers = input_to_ops_map.ConsumerOperations(rsqrt_op)
 
   return sum(1 for op in rsqrt_consumers if op.type == 'Mul') == 1
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index e88db0a..5e63d33 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -97,8 +97,11 @@
         layer_match.activation_op)
     add_context = context
     if layer_match.bypass_op:
-      add_context = re.search(r'^(.*)/([^/]+)', context).group(1)
-
+      pattern_match_result = re.search(r'^(.*)/([^/]+)', context)
+      if pattern_match_result is not None:
+        add_context = pattern_match_result.group(1)
+      else:
+        add_context = ''
     # If `scope` is given, only quantize it if the producer of weights
     # (usually it's the layer op) is in the right scope.
     _InsertQuantOp(
@@ -156,8 +159,12 @@
 
     # Quantize bypass ops that occur after the activation.
     if layer_match.post_activation_bypass_op is not None:
-      post_activation_bypass_context = re.search(
-          r'^(.*)/([^/]+)', layer_match.post_activation_bypass_op.name).group(1)
+      pattern_match_result = re.search(
+          r'^(.*)/([^/]+)', layer_match.post_activation_bypass_op.name)
+      if pattern_match_result is not None:
+        post_activation_bypass_context = pattern_match_result.group(1)
+      else:
+        post_activation_bypass_context = ''
       # If `scope` is given, only quantize it if the producer is in the right
       # scope.
       # Make sure the op following this isn't an activation. In which case, we
diff --git a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py
index 31a2955..f6bf57a 100644
--- a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py
@@ -58,85 +58,102 @@
     ]
     for params in parameters_list:
       # Test everything with resource variables and normal variables.
-      test_fn(params[0], params[1], params[2], params[3], False)
-      test_fn(params[0], params[1], params[2], params[3], True)
+      test_fn(params[0], params[1], params[2], params[3], False, None)
+      test_fn(params[0], params[1], params[2], params[3], True, None)
+      # Test with both empty scope and an example scope
+      test_fn(params[0], params[1], params[2], params[3], False, 'test')
+      test_fn(params[0], params[1], params[2], params[3], True, 'test')
 
   def _AssertCorrectQuantizedGraphWithoutBatchNorm(
       self, graph, scope, layer, activation_op_name, with_bypass, delay,
       use_resource):
     quantization_node_name = 'FakeQuantWithMinMaxVars'
-    weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' +
-                                                quantization_node_name)
+    conv_scope = self._GetConvScope(scope, with_bypass)
+    delim = '/' if conv_scope else ''
+
+    if scope:
+      scope = scope + '/'
+    weights_quant = graph.get_operation_by_name(
+        conv_scope + delim + 'weights_quant/' + quantization_node_name)
     self.assertEqual(weights_quant.type, quantization_node_name)
 
     # Assemble the expected inputs.
     if use_resource:
       expected_inputs = [
-          scope + '/weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
-          scope + '/weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
+          conv_scope + delim +
+          'weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
+          conv_scope + delim +
+          'weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
       ]
       if layer == 'DepthwiseConv2dNative':
-        expected_inputs.append(scope + '/depthwise/ReadVariableOp')
+        expected_inputs.append(conv_scope + delim + 'depthwise/ReadVariableOp')
       else:
-        expected_inputs.append(scope + '/' + layer + '/ReadVariableOp')
+        expected_inputs.append(conv_scope + delim + layer + '/ReadVariableOp')
     else:
       expected_inputs = [
-          scope + '/weights_quant/AssignMinLast',
-          scope + '/weights_quant/AssignMaxLast',
+          conv_scope + delim + 'weights_quant/AssignMinLast',
+          conv_scope + delim + 'weights_quant/AssignMaxLast',
       ]
       if layer == 'DepthwiseConv2dNative':
-        expected_inputs.append(scope + '/depthwise_weights/read')
+        expected_inputs.append(conv_scope + delim + 'depthwise_weights/read')
       else:
-        expected_inputs.append(scope + '/weights/read')
+        expected_inputs.append(conv_scope + delim + 'weights/read')
 
     self._AssertInputOpsAre(weights_quant, expected_inputs)
     if delay and delay > 0:
-      output_op_name = scope + '/weights_quant/delayed_quant/Switch_1'
+      output_op_name = (
+          conv_scope + delim + 'weights_quant/delayed_quant/Switch_1')
     else:
       if layer == 'DepthwiseConv2dNative':
-        output_op_name = scope + '/depthwise'
+        output_op_name = conv_scope + delim + 'depthwise'
       else:
-        output_op_name = scope + '/' + layer
+        output_op_name = conv_scope + delim + layer
 
     self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])
 
     if with_bypass:
-      conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' +
-                                               quantization_node_name)
+      conv_quant = graph.get_operation_by_name(
+          conv_scope + delim + 'conv_quant/' + quantization_node_name)
       self.assertEqual(conv_quant.type, quantization_node_name)
       if use_resource:
         expected_inputs = [
-            scope + '/conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
-            scope + '/conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
-            scope + '/BiasAdd',
+            conv_scope + delim +
+            'conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
+            conv_scope + delim +
+            'conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
+            conv_scope + delim + 'BiasAdd',
         ]
       else:
         expected_inputs = [
-            scope + '/conv_quant/AssignMinEma',
-            scope + '/conv_quant/AssignMaxEma', scope + '/BiasAdd'
+            conv_scope + delim + 'conv_quant/AssignMinEma',
+            conv_scope + delim + 'conv_quant/AssignMaxEma',
+            conv_scope + delim + 'BiasAdd'
         ]
       self._AssertInputOpsAre(conv_quant, expected_inputs)
-      output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1'
-                        if delay else 'test/Add')
+
+      output_op_name = (
+          conv_scope + delim + 'conv_quant/delayed_quant/Switch_1'
+          if delay else scope + 'Add')
       self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name])
 
-    act_quant = graph.get_operation_by_name('test/act_quant/' +
+    act_quant = graph.get_operation_by_name(scope + 'act_quant/' +
                                             quantization_node_name)
     self.assertEqual(act_quant.type, quantization_node_name)
     if use_resource:
       expected_inputs = [
-          'test/act_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
-          'test/act_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
-          'test/' + activation_op_name,
+          scope + 'act_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
+          scope + 'act_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
+          scope + activation_op_name,
       ]
     else:
       expected_inputs = [
-          'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma',
-          'test/' + activation_op_name
+          scope + 'act_quant/AssignMinEma', scope + 'act_quant/AssignMaxEma',
+          scope + activation_op_name
       ]
     self._AssertInputOpsAre(act_quant, expected_inputs)
-    output_op_name = ('test/act_quant/delayed_quant/Switch_1'
-                      if delay else 'control_dependency')
+    output_op_name = (
+        scope + 'act_quant/delayed_quant/Switch_1'
+        if delay else 'control_dependency')
     self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
     self._AssertIdempotent(graph)
 
@@ -145,7 +162,8 @@
         self._TestQuantize_Conv2dWithoutBatchNorm)
 
   def _TestQuantize_Conv2dWithoutBatchNorm(self, activation, activation_op_name,
-                                           with_bypass, delay, use_resource):
+                                           with_bypass, delay, use_resource,
+                                           scope):
     """Tests quantization: inputs -> Conv2d no batch norm -> Activation.
 
     Args:
@@ -156,6 +174,7 @@
         inputs to just before Activation.
       delay: Int (optional), delay in number of steps until quantization starts.
       use_resource: Bool, when true uses resource variables.
+      scope: String, specifies top level scope for the graph
     """
     graph = ops.Graph()
     with graph.as_default():
@@ -165,7 +184,9 @@
       stride = 1 if with_bypass else 2
       out_depth = 3 if with_bypass else 32
       activation_fn = None if with_bypass else activation
-      scope = 'test/test2' if with_bypass else 'test'
+      conv_scope = self._GetConvScope(scope, with_bypass)
+      scope = '' if scope is None else scope
+      delim = '/' if scope else ''
       node = conv2d(
           inputs,
           out_depth, [5, 5],
@@ -173,16 +194,19 @@
           padding='SAME',
           weights_initializer=self._WeightInit(0.09),
           activation_fn=activation_fn,
-          scope=scope)
+          scope=conv_scope)
       if with_bypass:
-        node = math_ops.add(inputs, node, name='test/Add')
-        node = activation(node, name='test/' + activation_op_name)
+        node = math_ops.add(inputs, node, name=scope + delim + 'Add')
+        node = activation(node, name=scope + delim + activation_op_name)
       update_barrier = control_flow_ops.no_op(name='update_barrier')
       with ops.control_dependencies([update_barrier]):
         array_ops.identity(node, name='control_dependency')
 
       quantize.Quantize(graph, True, quant_delay=delay)
 
+    if conv_scope is None:
+      conv_scope = ''
+
     self._AssertCorrectQuantizedGraphWithoutBatchNorm(
         graph, scope, 'Conv2D', activation_op_name, with_bypass, delay,
         use_resource)
@@ -192,7 +216,7 @@
         self._TestQuantize_FCWithoutBatchNorm)
 
   def _TestQuantize_FCWithoutBatchNorm(self, activation, activation_op_name,
-                                       with_bypass, delay, use_resource):
+                                       with_bypass, delay, use_resource, scope):
     """Tests quantization: inputs -> FC no batch norm -> Activation.
 
     Args:
@@ -203,6 +227,7 @@
         inputs to just before Activation.
       delay: Int (optional), delay in number of steps until quantization starts.
       use_resource: Bool, when true uses resource variables.
+      scope: String, specifies top level scope for the graph
     """
     graph = ops.Graph()
     with graph.as_default():
@@ -211,16 +236,18 @@
       inputs = array_ops.zeros((batch_size, depth))
       out_depth = 256 if with_bypass else 128
       activation_fn = None if with_bypass else activation
-      scope = 'test/test2' if with_bypass else 'test'
+      fc_scope = self._GetConvScope(scope, with_bypass)
+      scope = '' if scope is None else scope
+      delim = '/' if scope else ''
       node = fully_connected(
           inputs,
           out_depth,
           weights_initializer=self._WeightInit(0.03),
           activation_fn=activation_fn,
-          scope=scope)
+          scope=fc_scope)
       if with_bypass:
-        node = math_ops.add(inputs, node, name='test/Add')
-        node = activation(node, name='test/' + activation_op_name)
+        node = math_ops.add(inputs, node, name=scope + delim + 'Add')
+        node = activation(node, name=scope + delim + activation_op_name)
       update_barrier = control_flow_ops.no_op(name='update_barrier')
       with ops.control_dependencies([update_barrier]):
         array_ops.identity(node, name='control_dependency')
@@ -235,7 +262,8 @@
         self._TestQuantize_DepthwiseConv2dWithoutBatchNorm)
 
   def _TestQuantize_DepthwiseConv2dWithoutBatchNorm(
-      self, activation, activation_op_name, with_bypass, delay, use_resource):
+      self, activation, activation_op_name, with_bypass, delay, use_resource,
+      scope):
     """Tests quantization: inputs -> DWConv2d no batch norm -> Activation.
 
     Args:
@@ -246,6 +274,7 @@
         inputs to just before Activation.
       delay: Int (optional), delay in number of steps until quantization starts.
       use_resource: Bool, when true uses resource variables.
+      scope: String, specifies top level scope for the graph
     """
     graph = ops.Graph()
     with graph.as_default():
@@ -254,7 +283,10 @@
       inputs = array_ops.zeros((batch_size, height, width, depth))
       stride = 1 if with_bypass else 2
       activation_fn = None if with_bypass else activation
-      scope = 'test/test2' if with_bypass else 'test'
+      conv_scope = self._GetConvScope(scope, with_bypass)
+      scope = '' if scope is None else scope
+      delim = '/' if scope else ''
+
       node = separable_conv2d(
           inputs,
           None, [5, 5],
@@ -263,10 +295,10 @@
           padding='SAME',
           weights_initializer=self._WeightInit(0.09),
           activation_fn=activation_fn,
-          scope=scope)
+          scope=conv_scope)
       if with_bypass:
-        node = math_ops.add(inputs, node, name='test/Add')
-        node = activation(node, name='test/' + activation_op_name)
+        node = math_ops.add(inputs, node, name=scope + delim + 'Add')
+        node = activation(node, name=scope + delim + activation_op_name)
       update_barrier = control_flow_ops.no_op(name='update_barrier')
       with ops.control_dependencies([update_barrier]):
         array_ops.identity(node, name='control_dependency')
@@ -280,8 +312,9 @@
     self._RunWithoutBatchNormTestOverParameters(
         self._TestQuantize_AtrousConvWithoutBatchNorm)
 
-  def _TestQuantize_AtrousConvWithoutBatchNorm(
-      self, activation, activation_op_name, with_bypass, delay, use_resource):
+  def _TestQuantize_AtrousConvWithoutBatchNorm(self, activation,
+                                               activation_op_name, with_bypass,
+                                               delay, use_resource, scope):
     """Tests quantization: inputs -> atrous conv no batch norm -> Activation.
 
     Args:
@@ -292,6 +325,7 @@
         inputs to just before Activation.
       delay: Int (optional), delay in number of steps until quantization starts.
       use_resource: Bool, when true uses resource variables.
+      scope: String, specifies top level scope for the graph
     """
     graph = ops.Graph()
     with graph.as_default():
@@ -300,7 +334,10 @@
       inputs = array_ops.zeros((batch_size, height, width, depth))
       dilation_rate = 2
       activation_fn = None if with_bypass else activation
-      scope = 'test/test2' if with_bypass else 'test'
+      conv_scope = self._GetConvScope(scope, with_bypass)
+      scope = '' if scope is None else scope
+      delim = '/' if scope else ''
+
       node = separable_conv2d(
           inputs,
           None, [3, 3],
@@ -309,10 +346,10 @@
           padding='SAME',
           weights_initializer=self._WeightInit(0.09),
           activation_fn=activation_fn,
-          scope=scope)
+          scope=conv_scope)
       if with_bypass:
-        node = math_ops.add(inputs, node, name='test/Add')
-        node = activation(node, name='test/' + activation_op_name)
+        node = math_ops.add(inputs, node, name=scope + delim + 'Add')
+        node = activation(node, name=scope + delim + activation_op_name)
       update_barrier = control_flow_ops.no_op(name='update_barrier')
       with ops.control_dependencies([update_barrier]):
         array_ops.identity(node, name='control_dependency')
@@ -353,78 +390,96 @@
     ]
     for params in parameters_list:
       # Test everything with resource variables and normal variables.
-      test_fn(params[0], params[1], params[2], params[3], params[4], False)
-      test_fn(params[0], params[1], params[2], params[3], params[4], True)
+      test_fn(params[0], params[1], params[2], params[3], params[4], False,
+              None)
+      test_fn(params[0], params[1], params[2], params[3], params[4], True, None)
+      test_fn(params[0], params[1], params[2], params[3], params[4], False,
+              'test')
+      test_fn(params[0], params[1], params[2], params[3], params[4], True,
+              'test')
 
   def _AssertCorrectQuantizedGraphWithBatchNorm(self, graph, scope, layer,
                                                 activation_op_name, with_bypass,
                                                 delay, use_resource):
     quantization_node_name = 'FakeQuantWithMinMaxVars'
+    conv_scope = self._GetConvScope(scope, with_bypass)
+    delim = '/' if conv_scope else ''
+
+    if scope:
+      scope = scope + '/'
+
     weights_quant = graph.get_operation_by_name(
-        scope + '/weights_quant/' + quantization_node_name)
+        conv_scope + delim + 'weights_quant/' + quantization_node_name)
+
     self.assertEqual(weights_quant.type, quantization_node_name)
     if use_resource:
       expected_inputs = [
-          scope + '/weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
-          scope + '/weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
+          conv_scope + delim +
+          'weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
+          conv_scope + delim +
+          'weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
       ]
     else:
       expected_inputs = [
-          scope + '/weights_quant/' + 'AssignMinLast',
-          scope + '/weights_quant/' + 'AssignMaxLast'
+          conv_scope + delim + 'weights_quant/' + 'AssignMinLast',
+          conv_scope + delim + 'weights_quant/' + 'AssignMaxLast'
       ]
-    expected_inputs.append(scope + '/mul_fold')
+    expected_inputs.append(conv_scope + delim + 'mul_fold')
 
     self._AssertInputOpsAre(weights_quant, expected_inputs)
     if layer == 'DepthwiseConv2dNative':
-      output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1'
-                                if delay else '/depthwise_Fold')
+      output_op_name = conv_scope + delim + (
+          'weights_quant/delayed_quant/Switch_1' if delay else 'depthwise_Fold')
     else:
-      output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1'
-                                if delay else '/' + layer + '_Fold')
+      output_op_name = conv_scope + delim + (
+          'weights_quant/delayed_quant/Switch_1' if delay else layer + '_Fold')
     self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])
 
     if with_bypass:
       conv_quant = graph.get_operation_by_name(
-          scope + '/conv_quant/' + quantization_node_name)
+          conv_scope + delim + 'conv_quant/' + quantization_node_name)
       self.assertEqual(conv_quant.type, quantization_node_name)
 
       if use_resource:
         expected_inputs = [
-            scope + '/conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
-            scope + '/conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
+            conv_scope + delim +
+            'conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
+            conv_scope + delim +
+            'conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
         ]
       else:
         expected_inputs = [
-            scope + '/conv_quant/AssignMinEma',
-            scope + '/conv_quant/AssignMaxEma',
+            conv_scope + delim + 'conv_quant/AssignMinEma',
+            conv_scope + delim + 'conv_quant/AssignMaxEma',
         ]
-      expected_inputs.append(scope + '/add_fold')
+      expected_inputs.append(conv_scope + delim + 'add_fold')
 
       self._AssertInputOpsAre(conv_quant, expected_inputs)
       output_op_name = (
-          scope + '/conv_quant/delayed_quant/Switch_1' if delay else 'test/Add')
+          conv_scope + delim + 'conv_quant/delayed_quant/Switch_1'
+          if delay else scope + 'Add')
       self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name])
 
-    act_quant = graph.get_operation_by_name(
-        'test/act_quant/' + quantization_node_name)
+    act_quant = graph.get_operation_by_name(scope + 'act_quant/' +
+                                            quantization_node_name)
     self.assertEqual(act_quant.type, quantization_node_name)
 
     if use_resource:
       expected_inputs = [
-          'test/act_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
-          'test/act_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
+          scope + 'act_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
+          scope + 'act_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
       ]
     else:
       expected_inputs = [
-          'test/act_quant/AssignMinEma',
-          'test/act_quant/AssignMaxEma',
+          scope + 'act_quant/AssignMinEma',
+          scope + 'act_quant/AssignMaxEma',
       ]
-    expected_inputs.append('test/' + activation_op_name)
+    expected_inputs.append(scope + activation_op_name)
 
     self._AssertInputOpsAre(act_quant, expected_inputs)
-    output_op_name = ('test/act_quant/delayed_quant/Switch_1'
-                      if delay else 'control_dependency')
+    output_op_name = (
+        scope + 'act_quant/delayed_quant/Switch_1'
+        if delay else 'control_dependency')
     self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
     self._AssertIdempotent(graph)
 
@@ -433,7 +488,7 @@
 
   def _TestQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name,
                                         with_bypass, delay, fused_batch_norm,
-                                        use_resource):
+                                        use_resource, scope):
     """Tests quantization: inputs -> Conv2d with batch norm -> Activation.
 
     Args:
@@ -445,6 +500,7 @@
       delay: Int (optional), delay in number of steps until quantization starts.
       fused_batch_norm: Bool, when true use FusedBatchNorm.
       use_resource: Bool, when true uses resource variables.
+      scope: String, specifies the top-level scope for the graph.
     """
     graph = ops.Graph()
     with graph.as_default():
@@ -453,7 +509,9 @@
       inputs = array_ops.zeros((batch_size, height, width, depth))
       stride = 1 if with_bypass else 2
       out_depth = 3 if with_bypass else 32
-      scope = 'test/test2' if with_bypass else 'test'
+      conv_scope = self._GetConvScope(scope, with_bypass)
+      scope = '' if scope is None else scope
+      delim = '/' if scope else ''
       node = conv2d(
           inputs,
           out_depth, [5, 5],
@@ -463,13 +521,13 @@
           activation_fn=None,
           normalizer_fn=batch_norm,
           normalizer_params=self._BatchNormParams(fused_batch_norm),
-          scope=scope)
+          scope=conv_scope)
 
       # Manually add a bypass (optional) and an activation.
       if with_bypass:
-        node = math_ops.add(inputs, node, name='test/Add')
+        node = math_ops.add(inputs, node, name=scope + delim + 'Add')
 
-      node = activation(node, name='test/' + activation_op_name)
+      node = activation(node, name=scope + delim + activation_op_name)
 
       update_barrier = control_flow_ops.no_op(name='update_barrier')
       with ops.control_dependencies([update_barrier]):
@@ -487,7 +545,7 @@
 
   def _TestQuantize_FCWithBatchNorm(self, activation, activation_op_name,
                                     with_bypass, delay, fused_batch_norm,
-                                    use_resource):
+                                    use_resource, scope):
     """Tests quantization: inputs -> FC with batch norm -> Activation.
 
     Args:
@@ -499,6 +557,7 @@
       delay: Int (optional), delay in number of steps until quantization starts.
       fused_batch_norm: Bool, when true use FusedBatchNorm.
       use_resource: Bool, when true uses resource variables.
+      scope: String, specifies the top-level scope for the graph.
     """
     graph = ops.Graph()
     with graph.as_default():
@@ -506,7 +565,9 @@
       batch_size, depth = 5, 256
       inputs = array_ops.zeros((batch_size, depth))
       out_depth = 256 if with_bypass else 128
-      scope = 'test/test2' if with_bypass else 'test'
+      conv_scope = self._GetConvScope(scope, with_bypass)
+      scope = '' if scope is None else scope
+      delim = '/' if scope else ''
       node = fully_connected(
           inputs,
           out_depth,
@@ -514,13 +575,13 @@
           activation_fn=None,
           normalizer_fn=batch_norm,
           normalizer_params=self._BatchNormParams(fused_batch_norm),
-          scope=scope)
+          scope=conv_scope)
 
       # Manually add a bypass (optional) and an activation.
       if with_bypass:
-        node = math_ops.add(inputs, node, name='test/Add')
+        node = math_ops.add(inputs, node, name=scope + delim + 'Add')
 
-      node = activation(node, name='test/' + activation_op_name)
+      node = activation(node, name=scope + delim + activation_op_name)
 
       update_barrier = control_flow_ops.no_op(name='update_barrier')
       with ops.control_dependencies([update_barrier]):
@@ -540,7 +601,7 @@
 
   def _TestQuantize_DepthwiseConv2dWithBatchNorm(
       self, activation, activation_op_name, with_bypass, delay,
-      fused_batch_norm, use_resource):
+      fused_batch_norm, use_resource, scope):
     """Tests quantization: inputs -> DWConv2d with batch norm -> Activation.
 
     Args:
@@ -552,6 +613,7 @@
       delay: Int (optional), delay in number of steps until quantization starts.
       fused_batch_norm: Bool, when true use FusedBatchNorm.
       use_resource: Bool, when true uses resource variables.
+      scope: String, specifies the top-level scope for the graph.
     """
     graph = ops.Graph()
     with graph.as_default():
@@ -559,7 +621,9 @@
       batch_size, height, width, depth = 5, 128, 128, 3
       inputs = array_ops.zeros((batch_size, height, width, depth))
       stride = 1 if with_bypass else 2
-      scope = 'test/test2' if with_bypass else 'test'
+      conv_scope = self._GetConvScope(scope, with_bypass)
+      scope = '' if scope is None else scope
+      delim = '/' if scope else ''
       node = separable_conv2d(
           inputs,
           None, [5, 5],
@@ -570,13 +634,13 @@
           activation_fn=None,
           normalizer_fn=batch_norm,
           normalizer_params=self._BatchNormParams(fused_batch_norm),
-          scope=scope)
+          scope=conv_scope)
 
       # Manually add a bypass (optional) and an activation.
       if with_bypass:
-        node = math_ops.add(inputs, node, name='test/Add')
+        node = math_ops.add(inputs, node, name=scope + delim + 'Add')
 
-      node = activation(node, name='test/' + activation_op_name)
+      node = activation(node, name=scope + delim + activation_op_name)
 
       update_barrier = control_flow_ops.no_op(name='update_barrier')
       with ops.control_dependencies([update_barrier]):
@@ -595,7 +659,7 @@
 
   def _TestQuantize_AtrousConvWithBatchNorm(
       self, activation, activation_op_name, with_bypass, delay,
-      fused_batch_norm, use_resource):
+      fused_batch_norm, use_resource, scope):
     """Tests quantization: inputs -> atrous conv with batch norm -> Activation.
 
     Args:
@@ -607,6 +671,7 @@
       delay: Int (optional), delay in number of steps until quantization starts.
       fused_batch_norm: Bool, when true use FusedBatchNorm.
       use_resource: Bool, when true uses resource variables.
+      scope: String, specifies the top-level scope for the graph.
     """
     graph = ops.Graph()
     with graph.as_default():
@@ -614,7 +679,10 @@
       batch_size, height, width, depth = 5, 128, 128, 3
       inputs = array_ops.zeros((batch_size, height, width, depth))
       dilation_rate = 2
-      scope = 'test/test2' if with_bypass else 'test'
+      conv_scope = self._GetConvScope(scope, with_bypass)
+      scope = '' if scope is None else scope
+      delim = '/' if scope else ''
+
       node = separable_conv2d(
           inputs,
           None, [3, 3],
@@ -625,13 +693,13 @@
           activation_fn=None,
           normalizer_fn=batch_norm,
           normalizer_params=self._BatchNormParams(fused_batch_norm),
-          scope=scope)
+          scope=conv_scope)
 
       # Manually add a bypass (optional) and an activation.
       if with_bypass:
-        node = math_ops.add(inputs, node, name='test/Add')
+        node = math_ops.add(inputs, node, name=scope + delim + 'Add')
 
-      node = activation(node, name='test/' + activation_op_name)
+      node = activation(node, name=scope + delim + activation_op_name)
 
       update_barrier = control_flow_ops.no_op(name='update_barrier')
       with ops.control_dependencies([update_barrier]):
@@ -718,6 +786,18 @@
     with open('/tmp/bn_quant_test.pbtxt', 'w') as f:
       f.write(str(graph.as_graph_def()))
 
+  def _GetConvScope(self, scope, with_bypass):
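+    """Returns the scope in which the layer under test is built.
+
+    When `with_bypass` is True the layer goes into a 'test2' sub-scope of
+    `scope`; otherwise it is created directly under `scope` (which may be
+    empty when no top-level scope is given).
+    """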
+    if scope is None:
+      scope = ''
+    delim = '/' if scope else ''
+
+    if with_bypass:
+      conv_scope = scope + delim + 'test2'
+    else:
+      conv_scope = scope
+
+    return conv_scope
+
   def _BatchNormParams(self, fused=False, force_updates=False):
     params = {
         'center': True,
diff --git a/tensorflow/contrib/rate/rate_test.py b/tensorflow/contrib/rate/rate_test.py
index 0890810..3dee163 100644
--- a/tensorflow/contrib/rate/rate_test.py
+++ b/tensorflow/contrib/rate/rate_test.py
@@ -46,7 +46,7 @@
 
   @test_util.run_in_graph_and_eager_modes()
   def testBasic(self):
-    with self.test_session():
+    with self.cached_session():
       r_ = rate.Rate()
       a = r_(array_ops.ones([1]), denominator=array_ops.ones([1]))
       self.evaluate(variables.global_variables_initializer())
@@ -67,7 +67,7 @@
 
   @test_util.run_in_graph_and_eager_modes()
   def testWhileLoop(self):
-    with self.test_session():
+    with self.cached_session():
       r_ = rate.Rate()
 
       def body(value, denom, i, ret_rate):
diff --git a/tensorflow/contrib/resampler/python/ops/resampler_ops_test.py b/tensorflow/contrib/resampler/python/ops/resampler_ops_test.py
index 6253f96..e30e725 100644
--- a/tensorflow/contrib/resampler/python/ops/resampler_ops_test.py
+++ b/tensorflow/contrib/resampler/python/ops/resampler_ops_test.py
@@ -210,7 +210,7 @@
 
     # Input data shape is not defined over a 2D grid, i.e. its shape is not like
     # (batch_size, data_height, data_width, data_channels).
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       data_shape = (batch_size, data_height, data_width, data_depth,
                     data_channels)
       data = np.zeros(data_shape)
@@ -225,7 +225,7 @@
         sess.run(outputs)
 
     # Warp tensor must be at least a matrix, with shape [batch_size, 2].
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       data_shape = (batch_size, data_height, data_width, data_channels)
       data = np.zeros(data_shape)
       warp_shape = (batch_size,)
@@ -238,7 +238,7 @@
         sess.run(outputs)
 
     # The batch size of the data and warp tensors must be the same.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       data_shape = (batch_size, data_height, data_width, data_channels)
       data = np.zeros(data_shape)
       warp_shape = (batch_size+1, warp_height, warp_width, 2)
@@ -252,7 +252,7 @@
 
     # The warp tensor must contain 2D coordinates, i.e. its shape last dimension
     # must be 2.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       data_shape = (batch_size, data_height, data_width, data_channels)
       data = np.zeros(data_shape)
       warp_shape = (batch_size, warp_height, warp_width, 3)
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
index 1c23c28..0d61592 100644
--- a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
+++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
@@ -49,7 +49,7 @@
     return rpc_op.try_rpc(*args, protocol=self._protocol, **kwargs)
 
   def testScalarHostPortRpc(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       request_tensors = (
           test_example_pb2.TestCase(values=[1, 2, 3]).SerializeToString())
       response_tensors = self.rpc(
@@ -63,7 +63,7 @@
     self.assertAllEqual([2, 3, 4], response_message.values)
 
   def testScalarHostPortTryRpc(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       request_tensors = (
           test_example_pb2.TestCase(values=[1, 2, 3]).SerializeToString())
       response_tensors, status_code, status_message = self.try_rpc(
@@ -83,7 +83,7 @@
     self.assertEqual(b'', status_message_values)
 
   def testEmptyHostPortRpc(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       request_tensors = []
       response_tensors = self.rpc(
           method=self.get_method_name('Increment'),
@@ -98,7 +98,7 @@
         '/InvalidService.Increment',
         self.get_method_name('InvalidMethodName')
     ]:
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         with self.assertRaisesOpError(self.invalid_method_string):
           sess.run(self.rpc(method=method, address=self._address, request=''))
 
@@ -111,7 +111,7 @@
   def testInvalidAddress(self):
     # This covers the case of address='' and address='localhost:293874293874'
     address = 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with self.assertRaises(errors.UnavailableError):
         sess.run(
             self.rpc(
@@ -128,7 +128,7 @@
           self.connect_failed_string in status_message_value.decode('ascii'))
 
   def testAlwaysFailingMethod(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       response_tensors = self.rpc(
           method=self.get_method_name('AlwaysFailWithInvalidArgument'),
           address=self._address,
@@ -150,7 +150,7 @@
       self.assertTrue(I_WARNED_YOU in status_message_value.decode('ascii'))
 
   def testSometimesFailingMethodWithManyRequests(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Fail hard by default.
       response_tensors = self.rpc(
           method=self.get_method_name('SometimesFailWithInvalidArgument'),
@@ -179,7 +179,7 @@
       self.assertAllEqual(expected_message_values, status_message_values)
 
   def testVecHostPortRpc(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       request_tensors = [
           test_example_pb2.TestCase(
               values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
@@ -197,7 +197,7 @@
       self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values)
 
   def testVecHostPortManyParallelRpcs(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       request_tensors = [
           test_example_pb2.TestCase(
               values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
@@ -219,7 +219,7 @@
         self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values)
 
   def testVecHostPortRpcUsingEncodeAndDecodeProto(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       request_tensors = encode_proto_op.encode_proto(
           message_type='tensorflow.contrib.rpc.TestCase',
           field_names=['values'],
@@ -241,7 +241,7 @@
                          for i in range(20)], response_shape_values)
 
   def testVecHostPortRpcCancelsUponSessionTimeOutWhenSleepingForever(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       request_tensors = [''] * 25  # This will launch 25 RPC requests.
       response_tensors = self.rpc(
           method=self.get_method_name('SleepForever'),
@@ -254,7 +254,7 @@
           sess.run(response_tensors, options=options)
 
   def testVecHostPortRpcCancelsUponConfiguredTimeOutWhenSleepingForever(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       request_tensors = [''] * 25  # This will launch 25 RPC requests.
       response_tensors = self.rpc(
           method=self.get_method_name('SleepForever'),
@@ -265,7 +265,7 @@
         sess.run(response_tensors)
 
   def testTryRpcPropagatesDeadlineErrorWithSometimesTimingOutRequests(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       response_tensors, status_code, status_message = self.try_rpc(
           method=self.get_method_name('SometimesSleepForever'),
           timeout_in_ms=1000,
@@ -281,7 +281,7 @@
 
   def testTryRpcWithMultipleAddressesSingleRequest(self):
     flatten = lambda x: list(itertools.chain.from_iterable(x))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       addresses = flatten([[
           self._address, 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
       ] for _ in range(10)])
@@ -301,7 +301,7 @@
 
   def testTryRpcWithMultipleMethodsSingleRequest(self):
     flatten = lambda x: list(itertools.chain.from_iterable(x))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       methods = flatten(
           [[self.get_method_name('Increment'), 'InvalidMethodName']
            for _ in range(10)])
@@ -319,7 +319,7 @@
 
   def testTryRpcWithMultipleAddressesAndRequests(self):
     flatten = lambda x: list(itertools.chain.from_iterable(x))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       addresses = flatten([[
           self._address, 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
       ] for _ in range(10)])
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
index f2c43f3..1f3b533 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
@@ -919,31 +919,28 @@
         wrapper.BahdanauAttention, wrapper.LuongAttention)
 
     expected_final_output = BasicDecoderOutput(
-        rnn_output=ResultSummary(shape=(5, 3, 20),
-                                 dtype=dtype('float32'),
-                                 mean=0.11723966),
-        sample_id=ResultSummary(shape=(5, 3),
-                                dtype=dtype('int32'),
-                                mean=9.2666666666666675))
+        rnn_output=ResultSummary(
+            shape=(5, 3, 20), dtype=dtype('float32'), mean=0.11723966),
+        sample_id=ResultSummary(
+            shape=(5, 3), dtype=dtype('int32'), mean=7.266666666666667))
     expected_final_state = AttentionWrapperState(
         cell_state=LSTMStateTuple(
-            c=ResultSummary(shape=(5, 9),
-                            dtype=dtype('float32'),
-                            mean=-0.003545674),
-            h=ResultSummary(shape=(5, 9),
-                            dtype=dtype('float32'),
-                            mean=-0.0018327223)),
-        attention=ResultSummary(shape=(5, 20),
-                                dtype=dtype('float32'),
-                                mean=0.11728073),
+            c=ResultSummary(
+                shape=(5, 9), dtype=dtype('float32'), mean=-0.003545674),
+            h=ResultSummary(
+                shape=(5, 9), dtype=dtype('float32'), mean=-0.0018327223)),
+        attention=ResultSummary(
+            shape=(5, 20), dtype=dtype('float32'), mean=0.11601614207),
         time=3,
-        alignments=(
-            ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125),
-            ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)),
+        alignments=(ResultSummary(
+            shape=(5, 8), dtype=dtype('float32'), mean=0.125),
+                    ResultSummary(
+                        shape=(5, 8), dtype=dtype('float32'), mean=0.125)),
         alignment_history=(),
-        attention_state=(
-            ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125),
-            ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)))
+        attention_state=(ResultSummary(
+            shape=(5, 8), dtype=dtype('float32'), mean=0.125),
+                         ResultSummary(
+                             shape=(5, 8), dtype=dtype('float32'), mean=0.125)))
     expected_final_alignment_history = (
         ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125),
         ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125))
diff --git a/tensorflow/contrib/session_bundle/bundle_shim.cc b/tensorflow/contrib/session_bundle/bundle_shim.cc
index 4fc36d8..c669ced 100644
--- a/tensorflow/contrib/session_bundle/bundle_shim.cc
+++ b/tensorflow/contrib/session_bundle/bundle_shim.cc
@@ -355,11 +355,15 @@
     const SessionOptions& session_options, const RunOptions& run_options,
     const string& export_dir,
     const std::unordered_set<string>& saved_model_tags,
-    SavedModelBundle* saved_model_bundle) {
+    SavedModelBundle* saved_model_bundle, bool* is_session_bundle) {
+  if (is_session_bundle != nullptr) {
+    *is_session_bundle = false;
+  }
   if (MaybeSavedModelDirectory(export_dir)) {
     LOG(INFO)
         << "Attempting to load native SavedModelBundle in bundle-shim from: "
         << export_dir;
+
     return LoadSavedModel(session_options, run_options, export_dir,
                           saved_model_tags, saved_model_bundle);
   } else if (IsPossibleExportDirectory(export_dir)) {
@@ -368,6 +372,9 @@
     LOG(INFO) << "Attempting to up-convert SessionBundle to SavedModelBundle "
                  "in bundle-shim from: "
               << export_dir;
+    if (is_session_bundle != nullptr) {
+      *is_session_bundle = true;
+    }
     return LoadSavedModelFromLegacySessionBundlePath(
         session_options, run_options, export_dir, saved_model_bundle);
   }
diff --git a/tensorflow/contrib/session_bundle/bundle_shim.h b/tensorflow/contrib/session_bundle/bundle_shim.h
index 4628b6a..7f0f995 100644
--- a/tensorflow/contrib/session_bundle/bundle_shim.h
+++ b/tensorflow/contrib/session_bundle/bundle_shim.h
@@ -59,11 +59,13 @@
 }  // namespace internal
 
 // Loads a SavedModel from either a session-bundle path or a SavedModel bundle
-// path.
+// path. If `is_session_bundle` is not nullptr, sets it to `true` iff the
+// SavedModel was up-converted and loaded from a SessionBundle. The value of
+// `is_session_bundle` should not be used if an error is returned.
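+//
+// Illustrative usage (variable names are placeholders):
+//   SavedModelBundle bundle;
+//   bool is_session_bundle = false;
+//   TF_RETURN_IF_ERROR(LoadSessionBundleOrSavedModelBundle(
+//       session_options, run_options, export_dir, {"serve"}, &bundle,
+//       &is_session_bundle));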
 Status LoadSessionBundleOrSavedModelBundle(
     const SessionOptions& session_options, const RunOptions& run_options,
     const string& export_dir, const std::unordered_set<string>& tags,
-    SavedModelBundle* bundle);
+    SavedModelBundle* bundle, bool* is_session_bundle = nullptr);
 
 }  // namespace serving
 }  // namespace tensorflow
diff --git a/tensorflow/contrib/session_bundle/bundle_shim_test.cc b/tensorflow/contrib/session_bundle/bundle_shim_test.cc
index 9a1dd93..815beb7 100644
--- a/tensorflow/contrib/session_bundle/bundle_shim_test.cc
+++ b/tensorflow/contrib/session_bundle/bundle_shim_test.cc
@@ -63,12 +63,16 @@
 
 void LoadAndValidateSavedModelBundle(const string& export_dir,
                                      const std::unordered_set<string>& tags,
-                                     const string& signature_def_key) {
+                                     const string& signature_def_key,
+                                     bool expect_session_bundle) {
   SessionOptions session_options;
   RunOptions run_options;
   SavedModelBundle saved_model_bundle;
+  bool is_session_bundle = false;
   TF_ASSERT_OK(LoadSessionBundleOrSavedModelBundle(
-      session_options, run_options, export_dir, tags, &saved_model_bundle));
+      session_options, run_options, export_dir, tags, &saved_model_bundle,
+      &is_session_bundle));
+  EXPECT_EQ(expect_session_bundle, is_session_bundle);
   const MetaGraphDef meta_graph_def = saved_model_bundle.meta_graph_def;
   const auto& signature_def_map = meta_graph_def.signature_def();
 
@@ -512,7 +516,8 @@
   const string session_bundle_export_dir =
       test_util::TestSrcDirPath(kSessionBundlePath);
   LoadAndValidateSavedModelBundle(session_bundle_export_dir, tags,
-                                  kDefaultServingSignatureDefKey);
+                                  kDefaultServingSignatureDefKey,
+                                  /*expect_session_bundle=*/true);
 
   // Verify that the named signature is also present.
   SessionOptions session_options;
@@ -558,7 +563,8 @@
   const string saved_model_bundle_export_dir =
       io::JoinPath(testing::TensorFlowSrcRoot(), kSavedModelBundlePath);
   LoadAndValidateSavedModelBundle(saved_model_bundle_export_dir,
-                                  {kSavedModelTagServe}, "regress_x_to_y");
+                                  {kSavedModelTagServe}, "regress_x_to_y",
+                                  /*expect_session_bundle=*/false);
 }
 
 // Checks a basic load fails with an invalid export path.
diff --git a/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py b/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py
index e4db5f2..e6a0b30 100644
--- a/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py
+++ b/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py
@@ -38,7 +38,7 @@
       graph_def = graph.as_graph_def()
       ss = pywrap_tensorflow.NewStatSummarizer(graph_def.SerializeToString())
 
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         sess.run(variables.global_variables_initializer())
 
         for _ in range(20):
diff --git a/tensorflow/contrib/summary/summary_ops_graph_test.py b/tensorflow/contrib/summary/summary_ops_graph_test.py
index ae8336d..807741e 100644
--- a/tensorflow/contrib/summary/summary_ops_graph_test.py
+++ b/tensorflow/contrib/summary/summary_ops_graph_test.py
@@ -52,7 +52,7 @@
       summary_ops.histogram('histogram', [1.0], step=1)
       summary_ops.image('image', [[[[1.0]]]], step=1)
       summary_ops.audio('audio', [[1.0]], 1.0, 1, step=1)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(summary_ops.summary_writer_initializer_op())
       sess.run(summary_ops.all_summary_ops())
     # The working condition of the ops is tested in the C++ test so we just
@@ -64,7 +64,7 @@
     writer = summary_ops.create_file_writer(logdir, max_queue=0)
     with writer.as_default(), summary_ops.always_record_summaries():
       summary_ops.scalar('scalar', 2.0, step=1)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(summary_ops.summary_writer_initializer_op())
       sess.run(summary_ops.all_summary_ops())
     events = summary_test_util.events_from_logdir(logdir)
@@ -77,7 +77,7 @@
     with writer.as_default(), summary_ops.always_record_summaries():
       with ops.name_scope('scope'):
         summary_ops.scalar('scalar', 2.0, step=1)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(summary_ops.summary_writer_initializer_op())
       sess.run(summary_ops.all_summary_ops())
     events = summary_test_util.events_from_logdir(logdir)
@@ -90,7 +90,7 @@
     writer = summary_ops.create_file_writer(logdir, max_queue=0)
     with writer.as_default(), summary_ops.always_record_summaries():
       summary_ops.scalar('scalar', 2.0)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.global_variables_initializer())
       sess.run(summary_ops.summary_writer_initializer_op())
       step, _ = sess.run(
@@ -105,7 +105,7 @@
         logdir, max_queue=1, flush_millis=999999)
     with writer.as_default(), summary_ops.always_record_summaries():
       summary_ops.scalar('scalar', 2.0, step=1)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(summary_ops.summary_writer_initializer_op())
       get_total = lambda: len(summary_test_util.events_from_logdir(logdir))
       # Note: First tf.Event is always file_version.
@@ -123,7 +123,7 @@
     with writer.as_default(), summary_ops.always_record_summaries():
       summary_ops.scalar('scalar', 2.0, step=1)
       flush_op = summary_ops.flush()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(summary_ops.summary_writer_initializer_op())
       get_total = lambda: len(summary_test_util.events_from_logdir(logdir))
       # Note: First tf.Event is always file_version.
@@ -157,7 +157,7 @@
       with writer3.as_default():
         summary_ops.scalar('three', 3.0, step=3)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Run init ops across writers sequentially to avoid race condition.
       # TODO(nickfelt): fix race condition in resource manager lookup or create
       sess.run(writer1.init())
@@ -191,7 +191,7 @@
           logdir, max_queue=100, flush_millis=1000000)
       with writer.as_default():
         summary_ops.scalar('one', 1.0, step=1)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(summary_ops.summary_writer_initializer_op())
       get_total = lambda: len(summary_test_util.events_from_logdir(logdir))
       self.assertEqual(1, get_total())  # file_version Event
@@ -219,7 +219,7 @@
           logdir, max_queue=100, flush_millis=1000000)
       with writer.as_default():
         summary_ops.scalar('one', 1.0, step=1)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(summary_ops.summary_writer_initializer_op())
       get_total = lambda: len(summary_test_util.events_from_logdir(logdir))
       self.assertEqual(1, get_total())  # file_version Event
@@ -241,7 +241,7 @@
     training_util.get_or_create_global_step()
     name = 'hi'
     graph = graph_pb2.GraphDef(node=(node_def_pb2.NodeDef(name=name),))
-    with self.test_session():
+    with self.cached_session():
       with self.create_db_writer().as_default():
         summary_ops.initialize(graph=graph)
     six.assertCountEqual(self, [name],
@@ -249,7 +249,7 @@
 
   def testScalarSummary(self):
     """Test record_summaries_every_n_global_steps and all_summaries()."""
-    with ops.Graph().as_default(), self.test_session() as sess:
+    with ops.Graph().as_default(), self.cached_session() as sess:
       global_step = training_util.get_or_create_global_step()
       global_step.initializer.run()
       with ops.device('/cpu:0'):
@@ -280,7 +280,7 @@
 
   def testScalarSummaryNameScope(self):
     """Test record_summaries_every_n_global_steps and all_summaries()."""
-    with ops.Graph().as_default(), self.test_session() as sess:
+    with ops.Graph().as_default(), self.cached_session() as sess:
       global_step = training_util.get_or_create_global_step()
       global_step.initializer.run()
       with ops.device('/cpu:0'):
@@ -311,7 +311,7 @@
           self.assertEqual(events[1].summary.value[0].tag, 'scope/my_scalar')
 
   def testSummaryGraphModeCond(self):
-    with ops.Graph().as_default(), self.test_session():
+    with ops.Graph().as_default(), self.cached_session():
       training_util.get_or_create_global_step()
       logdir = tempfile.mkdtemp()
       with summary_ops.create_file_writer(
@@ -332,7 +332,7 @@
       self.assertEqual(events[1].summary.value[0].tag, 'cond/scalar')
 
   def testSummaryGraphModeWhile(self):
-    with ops.Graph().as_default(), self.test_session():
+    with ops.Graph().as_default(), self.cached_session():
       training_util.get_or_create_global_step()
       logdir = tempfile.mkdtemp()
       with summary_ops.create_file_writer(
diff --git a/tensorflow/contrib/tensor_forest/client/eval_metrics_test.py b/tensorflow/contrib/tensor_forest/client/eval_metrics_test.py
index aa30919..d49928e 100644
--- a/tensorflow/contrib/tensor_forest/client/eval_metrics_test.py
+++ b/tensorflow/contrib/tensor_forest/client/eval_metrics_test.py
@@ -32,7 +32,7 @@
                                           [0.9, 0.8, 0.2], [0.6, 0.4, 0.8]])
     targets = constant_op.constant([[0], [2], [1], [1]])
     in_top_2_op, update_op = top_2_fn(probabilities, targets)
-    with self.test_session():
+    with self.cached_session():
       # initializes internal accuracy vars
       variables.local_variables_initializer().run()
       # need to call in order to run the in_top_2_op internal operations because
@@ -49,7 +49,7 @@
                                           [0.3, 0.6, 0.9, 0.4, 0.8, 0.6]])
     targets = constant_op.constant([3, 0, 2, 5, 1])
     in_top_3_op, update_op = top_3_fn(probabilities, targets)
-    with self.test_session():
+    with self.cached_session():
       # initializes internal accuracy vars
       variables.local_variables_initializer().run()
       # need to call in order to run the in_top_3_op internal operations because
@@ -61,7 +61,7 @@
     predictions = constant_op.constant([0, 1, 3, 6, 5, 2, 7, 6, 4, 9])
     targets = constant_op.constant([0, 1, 4, 6, 5, 1, 7, 5, 4, 8])
     accuracy_op, update_op = eval_metrics._accuracy(predictions, targets)
-    with self.test_session():
+    with self.cached_session():
       variables.local_variables_initializer().run()
       # need to call in order to run the accuracy_op internal operations because
       # it is a streaming function
@@ -74,7 +74,7 @@
     targets = constant_op.constant(
         [1.0, 4.3, 2.6, 0.5, 1.1, 0.7, 5.1, 3.4, 1.8])
     r2_op, update_op = eval_metrics._r2(scores, targets)
-    with self.test_session():
+    with self.cached_session():
       # initializes internal accuracy vars
       variables.local_variables_initializer().run()
       # need to call in order to run the r2_op internal operations because
diff --git a/tensorflow/contrib/tensor_forest/kernels/tree_utils.cc b/tensorflow/contrib/tensor_forest/kernels/tree_utils.cc
index cefcc96..dd5d028 100644
--- a/tensorflow/contrib/tensor_forest/kernels/tree_utils.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/tree_utils.cc
@@ -67,11 +67,11 @@
     const Eigen::Tensor<float, 1, Eigen::RowMajor>& splits,
     const Eigen::Tensor<float, 1, Eigen::RowMajor>& rights, int32 num_classes,
     int i) {
-  Eigen::array<int, 1> offsets;
+  Eigen::array<Eigen::Index, 1> offsets;
   // Class counts are stored with the total in [0], so the length of each
   // count vector is num_classes + 1.
   offsets[0] = i * (num_classes + 1) + 1;
-  Eigen::array<int, 1> extents;
+  Eigen::array<Eigen::Index, 1> extents;
   extents[0] = num_classes;
   return WeightedGiniImpurity(splits.slice(offsets, extents)) +
          WeightedGiniImpurity(rights.slice(offsets, extents));
@@ -97,7 +97,7 @@
   // arguments to ClassificationSplitScore.
   const Eigen::Tensor<float, 1, Eigen::RowMajor> splits =
       split_counts.Slice(accumulator, accumulator + 1).unaligned_flat<float>();
-  Eigen::array<int, 1> bcast;
+  Eigen::array<Eigen::Index, 1> bcast;
   bcast[0] = num_splits;
   const Eigen::Tensor<float, 1, Eigen::RowMajor> rights =
       tc.broadcast(bcast) - splits;
@@ -130,8 +130,8 @@
     const Eigen::Tensor<float, 1, Eigen::RowMajor>& right_sums,
     const Eigen::Tensor<float, 1, Eigen::RowMajor>& right_squares,
     int32 accumulator, int32 num_regression_dims, int i) {
-  Eigen::array<int, 1> offsets = {i * num_regression_dims + 1};
-  Eigen::array<int, 1> extents = {num_regression_dims - 1};
+  Eigen::array<Eigen::Index, 1> offsets = {i * num_regression_dims + 1};
+  Eigen::array<Eigen::Index, 1> extents = {num_regression_dims - 1};
   float left_count = splits_count_accessor(accumulator, i, 0);
   float right_count = totals_count_accessor(accumulator, 0) - left_count;
 
@@ -178,7 +178,7 @@
   const auto splits_count_accessor = split_sums.tensor<float, 3>();
   const auto totals_count_accessor = total_sums.tensor<float, 2>();
 
-  Eigen::array<int, 1> bcast;
+  Eigen::array<Eigen::Index, 1> bcast;
   bcast[0] = num_splits;
   const auto right_sums = tc_sum.broadcast(bcast) - splits_sum;
   const auto right_squares = tc_square.broadcast(bcast) - splits_square;
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py
index e429d12..1c4e18d 100644
--- a/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py
@@ -32,7 +32,7 @@
     indices = [[1], [10]]
     updates = [100., 200.]
 
-    with self.test_session():
+    with self.cached_session():
       variables.global_variables_initializer().run()
       tensor_forest_ops.scatter_add_ndim(input_data, indices, updates).run()
       self.assertAllEqual(
@@ -45,7 +45,7 @@
     indices = [[0, 0, 1], [1, 1, 2]]
     updates = [100., 200.]
 
-    with self.test_session():
+    with self.cached_session():
       variables.global_variables_initializer().run()
       tensor_forest_ops.scatter_add_ndim(input_data, indices, updates).run()
       self.assertAllEqual([[[1., 102., 3.], [4., 5., 6.]],
@@ -57,7 +57,7 @@
     indices = []
     updates = []
 
-    with self.test_session():
+    with self.cached_session():
       variables.global_variables_initializer().run()
       tensor_forest_ops.scatter_add_ndim(input_data, indices, updates).run()
       self.assertAllEqual(init_val, input_data.eval())
@@ -67,7 +67,7 @@
     input_data = variables.Variable(init_val)
     indices = [[0, 0, 1], [1, 1, 2]]
     updates = [100.]
-    with self.test_session():
+    with self.cached_session():
       variables.global_variables_initializer().run()
       with self.assertRaisesOpError(
           'Number of updates should be same as number of indices.'):
@@ -80,7 +80,7 @@
     indices = [[0, 0], [1, 1]]
     updates = [[100., 200., 300.], [400., 500., 600.]]
 
-    with self.test_session():
+    with self.cached_session():
       variables.global_variables_initializer().run()
       tensor_forest_ops.scatter_add_ndim(input_data, indices, updates).run()
       self.assertAllEqual([[[101., 202., 303.], [4., 5., 6.]],
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
index 1c9c818..e0f0c0d 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
@@ -149,7 +149,7 @@
     self.assertTrue(isinstance(probs, ops.Tensor))
     self.assertTrue(isinstance(paths, ops.Tensor))
     self.assertTrue(isinstance(var, ops.Tensor))
-    with self.test_session():
+    with self.cached_session():
       variables.global_variables_initializer().run()
       resources.initialize_resources(resources.shared_resources()).run()
       self.assertEquals(probs.eval().shape, (4, 2))
diff --git a/tensorflow/contrib/tensorboard/db/loader.cc b/tensorflow/contrib/tensorboard/db/loader.cc
index 4d7337a..6439328 100644
--- a/tensorflow/contrib/tensorboard/db/loader.cc
+++ b/tensorflow/contrib/tensorboard/db/loader.cc
@@ -111,10 +111,10 @@
     ++records;
   }
   uint64 elapsed = env->NowMicros() - start;
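+  // Bytes per second; if loading took less than a microsecond, report the raw
+  // byte count to avoid dividing by zero.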
+  uint64 bps = (elapsed == 0 ? offset : static_cast<uint64>(
+                                            offset / (elapsed / 1000000.0)));
   LOG(INFO) << "Loaded " << AddCommas(offset) << " bytes with "
-            << AddCommas(records) << " records at "
-            << AddCommas(offset / (elapsed / 1000000)) << " bps";
-
+            << AddCommas(records) << " records at " << AddCommas(bps) << " bps";
   return 0;
 }
 
diff --git a/tensorflow/contrib/tensorrt/README.md b/tensorflow/contrib/tensorrt/README.md
index 687dee0..caf8b6d 100644
--- a/tensorflow/contrib/tensorrt/README.md
+++ b/tensorflow/contrib/tensorrt/README.md
@@ -26,4 +26,4 @@
 In order to make use of TensorRT integration, you will need a local installation
 of TensorRT 3.0.4 from the [NVIDIA Developer website](https://developer.nvidia.com/tensorrt).
 Installation instructions for compatibility with TensorFlow are provided on the
-[TensorFlow Installation page](https://www.tensorflow.org/install/install_linux#nvidia_requirements_to_run_tensorflow_with_gpu_support).
+[TensorFlow GPU support](https://www.tensorflow.org/install/gpu) guide.
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
index b019c99..7ad9bf2 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
@@ -678,7 +678,7 @@
 // Function to construct a funcdef from the segment and add it to the graph.
 tensorflow::Status RegisterSegmentFunctionToFunctionLibrary(
     tensorflow::Graph* graph, const tensorflow::GraphDef& segment,
-    const string& name) {
+    const string& engine_name) {
   tensorflow::Graph sgraph(graph->flib_def());
   tensorflow::GraphConstructorOptions gcopts;
   TF_RETURN_IF_ERROR(
@@ -761,9 +761,9 @@
   tensorflow::FunctionDefLibrary fdeflib;
   auto native_segment = fdeflib.add_function();
   TF_RETURN_IF_ERROR(tensorflow::GraphToFunctionDef(
-      sgraph, StrCat(name, "_native_segment"), native_segment));
+      sgraph, StrCat(engine_name, "_native_segment"), native_segment));
   if (VLOG_IS_ON(7)) {
-    VLOG(7) << name << " Function_Def ";
+    VLOG(7) << engine_name << " Function_Def ";
     VLOG(7) << native_segment->DebugString();
   }
   VLOG(1) << "Adding funcdef to graphlib";
@@ -780,12 +780,12 @@
     // If device is not set, use the first found GPU device for the conversion.
     for (int tf_gpu_id_value = 0; tf_gpu_id_value < 100; ++tf_gpu_id_value) {
       TfGpuId tf_gpu_id(tf_gpu_id_value);
-      CudaGpuId cuda_gpu_id;
-      Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id);
+      PlatformGpuId platform_gpu_id;
+      Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id);
       if (s.ok()) {
         VLOG(1) << "Found TF GPU " << tf_gpu_id.value() << " at cuda device "
-                << cuda_gpu_id.value();
-        cuda_device_id = cuda_gpu_id.value();
+                << platform_gpu_id.value();
+        cuda_device_id = platform_gpu_id.value();
         GPUOptions gpu_options;
         // If the TF to Cuda gpu id mapping exist, the device and corresponding
         // allocator must have been initialized already, so the
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index c98b07a..0ce8917 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -693,8 +693,15 @@
       // TODO(jie): tf protobuf seems to be omitting the :0 suffix
       string output_name = node_def.name();
       if (i != 0) output_name = StrCat(output_name, ":", i);
+      // We need to check the name before setting it. For an Identity op whose
+      // output is its input, if that input is one of the engine inputs,
+      // setting the name here would overwrite the engine input bindings and
+      // cause a runtime error.
       if (output.is_tensor()) {
-        output.tensor()->setName(output_name.c_str());
+        const char* tensor_name = output.tensor()->getName();
+        if (tensor_name == nullptr || std::strlen(tensor_name) == 0) {
+          output.tensor()->setName(output_name.c_str());
+        }
       }
       VLOG(2) << "Adding out tensor " << output_name << ": "
               << output.DebugString();
@@ -779,12 +786,11 @@
       // skip control nodes
       if (input_name[0] == '^') continue;
       string name = input_name;
-      auto first = name.find_first_of(':');
-      // TODO(aaroey): why removing the colon but not the zero? A bug?
+      auto last = name.find_last_of(':');
       // TODO(aaroey): use TensorId
-      if (first != string::npos && first + 2 == name.size() &&
-          name[first + 1] == '0') {
-        name.erase(first);
+      if (last != string::npos && last + 2 == name.size() &&
+          name[last + 1] == '0') {
+        name.erase(last);
       }
 
       if (trt_tensors_.count(name)) {
@@ -2697,7 +2703,6 @@
   TrtUniquePtrType<nvinfer1::IBuilder> builder(
       nvinfer1::createInferBuilder(*logger));
   builder->setMaxBatchSize(max_batch_size);
-  // TODO(aaroey): use the allocator to allocate the TRT workspace.
   builder->setMaxWorkspaceSize(max_workspace_size_bytes);
 #if NV_TENSORRT_MAJOR > 3
   builder->setGpuAllocator(allocator);
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
index 2b42d81..88cf8d5 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -565,21 +565,22 @@
       new TRTInt8Calibrator(device_buffers_, batch_size, name()));
   const string label(name());
   auto segment_graph = &segment_graph_;
-  const int cuda_gpu_id = ctx->device()->tensorflow_gpu_device_info()->gpu_id;
-  if (cuda_gpu_id < 0) {
+  const int platform_gpu_id =
+      ctx->device()->tensorflow_gpu_device_info()->gpu_id;
+  if (platform_gpu_id < 0) {
     LOG(ERROR) << "Can't get gpu_device_info from context->device()";
     return tensorflow::errors::InvalidArgument(
         "Context->device doesn't contain device info!");
   }
   const int64 workspace_size_bytes = workspace_size_;
   cres->thr_.reset(new std::thread([cres, label, segment_graph, shapes,
-                                    cuda_gpu_id, workspace_size_bytes]() {
-    VLOG(0) << "Starting calibration thread on device " << cuda_gpu_id
+                                    platform_gpu_id, workspace_size_bytes]() {
+    VLOG(0) << "Starting calibration thread on device " << platform_gpu_id
             << ", Calibration Resource @ " << cres;
-    auto err = cudaSetDevice(cuda_gpu_id);
+    auto err = cudaSetDevice(platform_gpu_id);
     if (err != cudaSuccess) {
       // TODO(aaroey): should return error here.
-      LOG(ERROR) << "Couldn't set cuda device to " << cuda_gpu_id
+      LOG(ERROR) << "Couldn't set cuda device to " << platform_gpu_id
                  << " in calibration thread";
     }
     // ConvertGraphDefToEngine() will try to build the engine. This thread
diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator.cc b/tensorflow/contrib/tensorrt/resources/trt_allocator.cc
index d8f97bf..a942586 100644
--- a/tensorflow/contrib/tensorrt/resources/trt_allocator.cc
+++ b/tensorflow/contrib/tensorrt/resources/trt_allocator.cc
@@ -27,12 +27,16 @@
 namespace tensorrt {
 
 // std::align is not supported, so this method mimics its behavior.
-void* Align(size_t alignment, size_t size, void*& ptr, size_t& space) {
-  QCHECK_GT(alignment, 0) << "alignment must be greater than 0.";
+//
+// NOTE(aaroey): according to the TensorRT API,
+// nvinfer1::IGpuAllocator::allocate() uses uint64_t for its size and alignment
+// parameters, so we use the same type here for compatibility.
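+//
+// Illustratively, Align(8, 16, ptr, space) rounds `ptr` up to an 8-byte
+// boundary and returns it when at least 16 bytes of `space` remain after the
+// adjustment; otherwise it returns nullptr and leaves `ptr` and `space`
+// unchanged.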
+void* Align(uint64_t alignment, uint64_t size, void*& ptr, uint64_t& space) {
+  QCHECK_GT(alignment, 0ul) << "alignment must be greater than 0.";
   QCHECK_EQ(0, alignment & (alignment - 1)) << "Alignment must be power of 2.";
-  QCHECK_GT(size, 0) << "size must be greater than 0.";
+  QCHECK_GT(size, 0ul) << "size must be greater than 0.";
   QCHECK(ptr) << "ptr must not be nullptr.";
-  QCHECK_GT(space, 0) << "space must be greater than 0.";
+  QCHECK_GT(space, 0ul) << "space must be greater than 0.";
   const uintptr_t ptr_val = reinterpret_cast<uintptr_t>(ptr);
   QCHECK_GE(ptr_val + space, ptr_val) << "Provided space overflows.";
 
@@ -67,12 +71,16 @@
 
 void* TRTDeviceAllocator::allocate(uint64_t size, uint64_t alignment,
                                    uint32_t flags) {
+  if (size == 0) return nullptr;
   // WAR for allocator alignment requirement. Certain CUDA API calls require
   // GPU memory to be aligned to cudaDeviceProp::textureAlignment.
   // See issue #20856
   alignment = 512;
   assert((alignment & (alignment - 1)) == 0);  // zero or a power of 2.
-  size_t total_size = size + alignment;
+  uint64_t total_size = size + alignment;
+  // TODO(aaroey): AllocateRaw takes a size_t size as input, so it will produce
+  // unexpected results when TRT tries to allocate more bytes than a size_t can
+  // hold. Fix this.
   void* mem = allocator_->AllocateRaw(alignment, total_size);
   if (!mem) return nullptr;
 
diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator.h b/tensorflow/contrib/tensorrt/resources/trt_allocator.h
index 6f94492..dc9862b 100644
--- a/tensorflow/contrib/tensorrt/resources/trt_allocator.h
+++ b/tensorflow/contrib/tensorrt/resources/trt_allocator.h
@@ -29,7 +29,7 @@
 namespace tensorflow {
 namespace tensorrt {
 // std::align is not supported, so this function mimics its behavior.
-void* Align(size_t alignment, size_t size, void*& ptr, size_t& space);
+void* Align(uint64_t alignment, uint64_t size, void*& ptr, uint64_t& space);
 }  // namespace tensorrt
 }  // namespace tensorflow
 
diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc b/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc
index f515ed0..ad6b1d7 100644
--- a/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc
+++ b/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc
@@ -20,11 +20,11 @@
 namespace tensorflow {
 namespace tensorrt {
 
-bool RunTest(const size_t alignment, const size_t size,
-             const intptr_t orig_ptr_val, const size_t orig_space) {
+bool RunTest(const uint64_t alignment, const uint64_t size,
+             const intptr_t orig_ptr_val, const uint64_t orig_space) {
   void* const orig_ptr = reinterpret_cast<void*>(orig_ptr_val);
   void* ptr = orig_ptr;
-  size_t space = orig_space;
+  uint64_t space = orig_space;
   void* result = Align(alignment, size, ptr, space);
   if (result == nullptr) {
     EXPECT_EQ(orig_ptr, ptr);
@@ -43,24 +43,25 @@
 }
 
 TEST(TRTAllocatorTest, Align) {
-  for (const size_t space :
-       {1, 2, 3, 4, 7, 8, 9, 10, 16, 32, 511, 512, 513, 700, 12345}) {
-    for (size_t alignment = 1; alignment <= space * 4; alignment *= 2) {
-      for (const intptr_t ptr_val :
+  for (const uint64_t space :
+       {1ul, 2ul, 3ul, 4ul, 7ul, 8ul, 9ul, 10ul, 16ul, 32ul, 511ul, 512ul,
+        513ul, 700ul, 12345ul, 1ul << 32}) {
+    for (uint64_t alignment = 1; alignment <= space * 4; alignment *= 2) {
+      for (const uintptr_t ptr_val :
            {1ul, alignment == 1 ? 1ul : alignment - 1, alignment, alignment + 1,
             alignment + (alignment / 2)}) {
         if (ptr_val % alignment == 0) {
-          for (const size_t size :
+          for (const uint64_t size :
                {1ul, space == 1 ? 1ul : space - 1, space, space + 1}) {
             EXPECT_EQ(space >= size, RunTest(alignment, size, ptr_val, space));
           }
         } else {
           EXPECT_FALSE(RunTest(alignment, space, ptr_val, space));
-          const size_t diff = alignment - ptr_val % alignment;
+          const uint64_t diff = alignment - ptr_val % alignment;
           if (space > diff) {
             EXPECT_TRUE(
                 RunTest(alignment, space - diff, ptr_val + diff, space - diff));
-            for (const size_t size :
+            for (const uint64_t size :
                  {1ul, space - diff > 1 ? space - diff - 1 : 1ul, space - diff,
                   space - diff + 1, space - 1}) {
               EXPECT_EQ(space - diff >= size,
diff --git a/tensorflow/contrib/tensorrt/test/base_test.py b/tensorflow/contrib/tensorrt/test/base_test.py
index e9ac833..7e9ffb0 100644
--- a/tensorflow/contrib/tensorrt/test/base_test.py
+++ b/tensorflow/contrib/tensorrt/test/base_test.py
@@ -183,6 +183,12 @@
         "my_trt_op_0": ["c0", "c1", "add0", "add1", "mul0", "mul1"]
     }
 
+  def ShouldRunTest(self, run_params):
+    """Whether to run the test."""
+    # Disable the test in fp16 mode since multiple matmul and add ops together
+    # can cause overflow.
+    return run_params.precision_mode != "FP16"
+
 
 class PartiallyConvertedTestB(PartiallyConvertedTestA):
 
diff --git a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
index 62f4e52..d2f6534 100644
--- a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
+++ b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
@@ -144,14 +144,6 @@
     # mode, which is a bug. Re-enable this when trt library is fixed.
     return not trt_test.IsQuantizationMode(run_params.precision_mode)
 
-  def ExpectedAbsoluteTolerance(self, run_params):
-    """The absolute tolerance to compare floating point results."""
-    return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-03
-
-  def ExpectedRelativeTolerance(self, run_params):
-    """The relative tolerance to compare floating point results."""
-    return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-03
-
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
index fc647e4..4f935a7 100644
--- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
+++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
@@ -134,7 +134,7 @@
             dims[0] for dims in self._GetParamsCached().input_dims if len(dims)
         ]),
         max_workspace_size_bytes=1 << 25,
-        precision_mode=self._ToBytes(run_params.precision_mode),
+        precision_mode=run_params.precision_mode,
         minimum_segment_size=2,
         is_dynamic_op=run_params.dynamic_engine,
         maximum_cached_engines=1,
@@ -179,11 +179,11 @@
 
   def ExpectedAbsoluteTolerance(self, run_params):
     """The absolute tolerance to compare floating point results."""
-    return 1.e-06 if run_params.precision_mode == "FP32" else 1.e-03
+    return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-02
 
   def ExpectedRelativeTolerance(self, run_params):
     """The relative tolerance to compare floating point results."""
-    return 1.e-06 if run_params.precision_mode == "FP32" else 1.e-03
+    return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-02
 
   def _GetParamsCached(self):
     if self._trt_test_params is None:
@@ -414,6 +414,7 @@
     if not self.ShouldRunTest(run_params):
       return
     assert run_params.precision_mode in PRECISION_MODES
+    np.random.seed(12345)
 
     params = self._GetParamsCached()
     input_gdef = params.gdef
diff --git a/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py b/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py
index 84e3614..832d34d 100644
--- a/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py
+++ b/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py
@@ -63,7 +63,7 @@
         (b"jumps", b"brown"),
         (b"jumps", b"fox"),
     ])
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(expected_tokens, tokens.eval())
       self.assertAllEqual(expected_labels, labels.eval())
 
@@ -94,7 +94,7 @@
         (b"jumps", b"fox"),
         (b"jumps", b"jumps"),
     ])
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(expected_tokens, tokens.eval())
       self.assertAllEqual(expected_labels, labels.eval())
 
@@ -105,7 +105,7 @@
     # If emit_self_as_target is False (default), output will be empty.
     tokens, labels = text.skip_gram_sample(
         input_tensor, min_skips=0, max_skips=0, emit_self_as_target=False)
-    with self.test_session():
+    with self.cached_session():
       self.assertEqual(0, tokens.eval().size)
       self.assertEqual(0, labels.eval().size)
 
@@ -117,7 +117,7 @@
         (b"quick", b"quick"),
         (b"brown", b"brown"),
     ])
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(expected_tokens, tokens.eval())
       self.assertAllEqual(expected_labels, labels.eval())
 
@@ -134,7 +134,7 @@
         (b"brown", b"the"),
         (b"brown", b"quick"),
     ])
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(expected_tokens, tokens.eval())
       self.assertAllEqual(expected_labels, labels.eval())
 
@@ -150,7 +150,7 @@
         (b"quick", b"brown"),
         (b"brown", b"quick"),
     ])
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(expected_tokens, tokens.eval())
       self.assertAllEqual(expected_labels, labels.eval())
 
@@ -165,7 +165,7 @@
         (b"quick", b"brown"),
         (b"brown", b"quick"),
     ])
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(expected_tokens, tokens.eval())
       self.assertAllEqual(expected_labels, labels.eval())
 
@@ -196,7 +196,7 @@
         (b"over", b"fox"),
         (b"over", b"jumps"),
     ])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       tokens_eval, labels_eval = sess.run([tokens, labels])
       self.assertAllEqual(expected_tokens, tokens_eval)
       self.assertAllEqual(expected_labels, labels_eval)
@@ -222,7 +222,7 @@
     tokens_2, labels_2 = text.skip_gram_sample(
         input_tensor, min_skips=1, max_skips=5)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       tokens_1_eval, labels_1_eval, tokens_2_eval, labels_2_eval = sess.run(
           [tokens_1, labels_1, tokens_2, labels_2])
 
@@ -244,7 +244,7 @@
         (b"brown", b"fox"),
         (b"fox", b"brown"),
     ])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       coord = coordinator.Coordinator()
       threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
 
@@ -269,7 +269,7 @@
         (2, 3),
         (3, 2),
     ])
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(expected_tokens, tokens.eval())
       self.assertAllEqual(expected_labels, labels.eval())
 
@@ -286,7 +286,7 @@
     for min_skips, max_skips in invalid_skips:
       tokens, labels = text.skip_gram_sample(
           input_tensor, min_skips=min_skips, max_skips=max_skips)
-      with self.test_session() as sess, self.assertRaises(
+      with self.cached_session() as sess, self.assertRaises(
           errors.InvalidArgumentError):
         sess.run([tokens, labels])
 
@@ -338,7 +338,7 @@
     vocab_freq_table = lookup.HashTable(
         lookup.KeyValueTensorInitializer(keys, values), -1)
 
-    with self.test_session():
+    with self.cached_session():
       vocab_freq_table.init.run()
 
       # No vocab_freq_table specified - output should be the same as input.
@@ -395,7 +395,7 @@
     vocab_freq_table = lookup.HashTable(
         lookup.KeyValueTensorInitializer(keys, values), -1)
 
-    with self.test_session():
+    with self.cached_session():
       vocab_freq_table.init.run()
       output = skip_gram_ops._filter_input(
           input_tensor=input_tensor,
@@ -464,7 +464,7 @@
         (b"life", b"and"),
         (b"and", b"life"),
     ])
-    with self.test_session():
+    with self.cached_session():
       lookup_ops.tables_initializer().run()
       self.assertAllEqual(expected_tokens, tokens.eval())
       self.assertAllEqual(expected_labels, labels.eval())
@@ -510,7 +510,7 @@
         (b"to", b"life"),
         (b"life", b"to"),
     ])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       lookup_ops.tables_initializer().run()
       tokens_eval, labels_eval = sess.run([tokens, labels])
       self.assertAllEqual(expected_tokens, tokens_eval)
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index 298ffc1..4e0b612 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -80,7 +80,7 @@
         "tpu_embedding_ops",
     ],
     deps = [
-        "//tensorflow/contrib/tpu/proto:tpu_embedding_config_proto_cc",
+        "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_cc",
         "//tensorflow/core:lib_proto_parsing",
         "//tensorflow/core:protos_all_cc",
     ],
@@ -99,7 +99,7 @@
         "ops/tpu_embedding_ops.cc",
     ],
     deps = [
-        "//tensorflow/contrib/tpu/proto:tpu_embedding_config_proto_cc",
+        "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_cc",
         "//tensorflow/core:lib_proto_parsing",
     ],
 )
@@ -351,7 +351,7 @@
 
 tf_py_test(
     name = "topology_test",
-    size = "small",
+    size = "medium",
     srcs = ["python/tpu/topology_test.py"],
     additional_deps = [
         ":tpu",
diff --git a/tensorflow/contrib/tpu/ops/replication_ops.cc b/tensorflow/contrib/tpu/ops/replication_ops.cc
index 15a2bb1..285e11d 100644
--- a/tensorflow/contrib/tpu/ops/replication_ops.cc
+++ b/tensorflow/contrib/tpu/ops/replication_ops.cc
@@ -24,9 +24,11 @@
 
 REGISTER_OP("TPUReplicateMetadata")
     .Attr("num_replicas: int >= 0")
+    .Attr("num_cores_per_replica: int = 1")
     .Attr("topology: string = \"\"")
     .Attr("use_tpu: bool = true")
     .Attr("device_assignment: list(int) = []")
+    // Deprecated. Use num_cores_per_replica instead.
     .Attr("computation_shape: list(int) = []")
     .Attr("host_compute_core: list(string) = []")
     .SetShapeFn(shape_inference::UnknownShape);
@@ -93,11 +95,11 @@
 REGISTER_OP("TPUReplicate")
     .Attr("computation: func")
     .Attr("num_replicas: int >= 1")
+    .Attr("num_cores_per_replica: int = 1")
     .Attr("topology: string = \"\"")
     .Attr("use_tpu: bool = true")
     .Attr("device_assignment: list(int) = []")
     .Attr("host_compute_core: list(string) = []")
-    .Attr("computation_shape: list(int) = []")
     .Attr("Tinputs: list(type) >= 0")
     .Attr("Tbroadcast_inputs: list(type) >= 0")
     .Attr("NumVariables: int >= 0")
@@ -114,16 +116,15 @@
 
 computation: a function containing the computation to run.
 num_replicas: the number of replicas of the computation to run.
+num_cores_per_replica: the number of logical cores in each replica.
 topology: A serialized tensorflow.tpu.TopologyProto that describes the TPU
 topology.
 use_tpu: a bool indicating if this computation will run on TPU or CPU/GPU.
 Currently, only supports a default placement (computation is placed on GPU
 if one is available, and on CPU if not).
-computation_shape: a [mesh_dimension] array describing the shape of each
-  computation replica in numbers of cores in the TPU mesh.
 device_assignment: a flattened array with shape
-  [replica] + computation_shape + [mesh_dimension] that maps the coordinates of
-  logical cores in each replica of a computation to physical coordinates in
+  [replica, num_cores_per_replica, mesh_dimension] that maps the coordinates
+  of logical cores in each replica of a computation to physical coordinates in
   the TPU topology.
 Tinputs: the types of the arguments to 'computation'.
 inputs: the inputs to 'computation', flattened, in replica-major order.
diff --git a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
index 72d37f7..18b9893 100644
--- a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
+++ b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
@@ -13,7 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/contrib/tpu/proto/tpu_embedding_config.pb.h"
+#include "tensorflow/contrib/tpu/proto/tpu_embedding_configuration.pb.h"
 #include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/shape_inference.h"
@@ -88,12 +88,12 @@
 
   int table_id;
   TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
-  int64 num_tables = config.table_config_size();
+  int64 num_tables = config.table_descriptor_size();
   if (table_id >= num_tables) {
     return errors::InvalidArgument("Table id >= num_tables");
   }
-  int64 width = config.table_config(table_id).width();
-  int64 num_rows = config.table_config(table_id).num_rows();
+  int64 width = config.table_descriptor(table_id).dimension();
+  int64 num_rows = config.table_descriptor(table_id).vocabulary_size();
 
   TF_RETURN_IF_ERROR(c->set_output("parameters", {c->Matrix(num_rows, width)}));
   return Status::OK();
@@ -160,12 +160,12 @@
 
   int table_id;
   TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
-  int64 num_tables = config.table_config_size();
+  int64 num_tables = config.table_descriptor_size();
   if (table_id >= num_tables) {
     return errors::InvalidArgument("Table id >= num_tables");
   }
-  int64 width = config.table_config(table_id).width();
-  int64 num_rows = config.table_config(table_id).num_rows();
+  int64 width = config.table_descriptor(table_id).dimension();
+  int64 num_rows = config.table_descriptor(table_id).vocabulary_size();
 
   TF_RETURN_IF_ERROR(c->set_output("parameters", {c->Matrix(num_rows, width)}));
   TF_RETURN_IF_ERROR(
@@ -244,11 +244,11 @@
   if (!config.ParseFromString(config_string)) {
     return errors::InvalidArgument("Malformed tpu_embedding_config.");
   }
-  int64 batch_size = config.batch_size();
-  int64 num_tables = config.table_config_size();
+  int64 batch_size = config.batch_size_per_tensor_core();
+  int64 num_tables = config.table_descriptor_size();
   for (int table_id = 0; table_id < num_tables; ++table_id) {
-    int64 width = config.table_config(table_id).width();
-    int64 num_features = config.table_config(table_id).num_features();
+    int64 width = config.table_descriptor(table_id).dimension();
+    int64 num_features = config.table_descriptor(table_id).vocabulary_size();
     c->set_output(table_id, c->Matrix(batch_size * num_features, width));
   }
   return Status::OK();
diff --git a/tensorflow/contrib/tpu/proto/BUILD b/tensorflow/contrib/tpu/proto/BUILD
index 598b73b..c20cab8 100644
--- a/tensorflow/contrib/tpu/proto/BUILD
+++ b/tensorflow/contrib/tpu/proto/BUILD
@@ -10,12 +10,15 @@
 )
 
 tf_proto_library(
-    name = "tpu_embedding_config_proto",
+    name = "tpu_embedding_configuration_proto",
     srcs = [
-        "tpu_embedding_config.proto",
+        "tpu_embedding_configuration.proto",
     ],
     cc_api_version = 2,
-    protodeps = [":optimization_parameters_proto"],
+    protodeps = [
+        ":tpu_embedding_output_layout_proto",
+        ":optimization_parameters_proto",
+    ],
     visibility = ["//visibility:public"],
 )
 
@@ -29,6 +32,15 @@
 )
 
 tf_proto_library(
+    name = "tpu_embedding_output_layout_proto",
+    srcs = [
+        "tpu_embedding_output_layout.proto",
+    ],
+    cc_api_version = 2,
+    visibility = ["//visibility:public"],
+)
+
+tf_proto_library(
     name = "topology_proto",
     srcs = [
         "topology.proto",
diff --git a/tensorflow/contrib/tpu/proto/tpu_embedding_config.proto b/tensorflow/contrib/tpu/proto/tpu_embedding_config.proto
deleted file mode 100644
index 3476cc8..0000000
--- a/tensorflow/contrib/tpu/proto/tpu_embedding_config.proto
+++ /dev/null
@@ -1,66 +0,0 @@
-syntax = "proto3";
-
-package tensorflow.tpu;
-
-import "tensorflow/contrib/tpu/proto/optimization_parameters.proto";
-
-// The TPUEmbeddingConfiguration contains specification of TPU Embedding lookups
-// and gradient updates separate from the TF Graph.
-message TPUEmbeddingConfiguration {
-  // model_mode specifies whether the model is to be run in training or
-  // inference. In inference mode, gradient updates to embedding tables are not
-  // performed.
-  enum ModelMode {
-    INVALID = 0;
-    TRAINING = 1;
-    INFERENCE = 2;
-  }
-
-  ModelMode model_mode = 1;
-
-  // num_hosts is the number of host CPU systems in the training/inference job.
-  // Each embedding table must be sharded into num_hosts separate Variables,
-  // placed separately on the num_hosts CPU devices in the cluster. Sharding
-  // will be performed equivalently to the 'div' sharding_strategy option of
-  // embedding_lookup() and embedding_lookup_sparse().
-  int32 num_hosts = 2;
-
-  // The total number of TensorNodes. This is equal to num_hosts times the
-  // number of TensorNodes attached to each host.
-  int32 num_tensornodes = 3;
-
-  // The number of training examples per TensorNode.
-  int32 batch_size = 4;
-
-  // Each Embedding
-  message TPUEmbeddingTable {
-    // Name of the embedding table. This will be used to name Variables in the
-    // Tensorflow Graph.
-    string name = 1;
-
-    // Number of rows of the embedding table. The Variable created to hold the
-    // learned embedding table values will have shape (num_rows, width).
-    int32 num_rows = 3;
-
-    // Width of the embedding table. The Variable created to hold the
-    // learned embedding table values will have shape (num_rows, width).
-    int32 width = 4;
-
-    // Number of distinct embedding activation vectors per training example
-    // produced by lookups into this table during model evaluation. For each
-    // table, the Graph will receive an activations Tensor of shape
-    //   (batch_size * table.num_features, table.width).
-    // For example, num_features = 1 produces equivalent behavior to a single
-    // tf.nn.embedding_lookup() call. In the case of 'multivalent' embeddings,
-    // (i.e. tf.nn.embedding_lookup_sparse()) which compute weighted averages of
-    // embedding table rows, num_features is the number of vectors produced
-    // after averaging. In sequence models num_features is typically equal
-    // to the sequence length, since each sequence element must be represented
-    // separately to the convolutional or recurrent network.
-    int32 num_features = 5;
-
-    OptimizationParameters optimization_parameters = 6;
-  }
-
-  repeated TPUEmbeddingTable table_config = 5;
-}
diff --git a/tensorflow/contrib/tpu/proto/tpu_embedding_configuration.proto b/tensorflow/contrib/tpu/proto/tpu_embedding_configuration.proto
new file mode 100644
index 0000000..da19b13
--- /dev/null
+++ b/tensorflow/contrib/tpu/proto/tpu_embedding_configuration.proto
@@ -0,0 +1,95 @@
+syntax = "proto3";
+
+package tensorflow.tpu;
+
+import "tensorflow/contrib/tpu/proto/optimization_parameters.proto";
+import "tensorflow/contrib/tpu/proto/tpu_embedding_output_layout.proto";
+
+message TPUEmbeddingConfiguration {
+  // Description of the various embedding tables.
+  message TableDescriptor {
+    // Name of the table.
+    string name = 1;
+    // Size of the vocabulary (i.e., number of rows) in the table.
+    int32 vocabulary_size = 2;
+    // The embedding dimension (i.e., the width of the embedding table).
+    int32 dimension = 3;
+    // Number of features mapped to this table.
+    int32 num_features = 4;
+    // Details of the learning algorithm used to update the embedding
+    // parameters.
+    OptimizationParameters optimization_parameters = 5;
+  }
+  repeated TableDescriptor table_descriptor = 1;
+
+  // Mode. Whether the embedding layer program should run inference (forward
+  // pass only), training (forward and backward passes), or only the backward pass.
+  enum Mode {
+    UNSPECIFIED = 0;
+    INFERENCE = 1;
+    TRAINING = 2;
+    BACKWARD_PASS_ONLY = 3;
+  }
+  Mode mode = 2;
+
+  // Number of samples in each batch of embedding layer activations sent to
+  // the TensorCore.
+  int32 batch_size_per_tensor_core = 3;
+
+  // Number of TPU hosts used for inference/training.
+  int32 num_hosts = 4;
+
+  // Number of TensorCores used for inference/training.
+  int32 num_tensor_cores = 5;
+
+  // Sharding strategy of the embedding tables among the hosts.
+  // If the sharding_strategy is "mod", each id is assigned to host
+  // "id % num_hosts". For instance, 13 ids are split across 5 hosts as:
+  // [[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]].
+  // If the sharding_strategy is "div", ids are assigned to hosts in a
+  // contiguous manner. In this case, 13 ids are split across 5 hosts as:
+  // [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]].
+  // In both the strategies, if the id space does not evenly divide the number
+  // of hosts, each of the first "table_descriptor.num_ids % num_hosts" hosts
+  // will be assigned one more id.
+  // This partitioning strategy exactly follows that in the embedding_lookup
+  // TensorFlow function at tensorflow/python/ops/embedding_ops.py.
+  enum ShardingStrategy {
+    DIV_DEFAULT = 0;
+    MOD = 1;
+  }
+  ShardingStrategy sharding_strategy = 6;
+
+  // This parameter determines if the execution of the sparse core will be
+  // pipelined with that of the TensorCore. This parameter only affects results
+  // when mode=TRAINING. If mode=INFERENCE or BACKWARD_PASS_ONLY, this parameter
+  // does not affect execution and hence is a don't-care value.
+  //
+  // false: The execution of the sparse core is not pipelined with that of the
+  // TensorCore. The forward pass of every step on the sparse core is executed
+  // only after the backward pass of the previous step is complete. And the
+  // backward pass on the sparse core is executed only after the embedding
+  // gradients have been computed on the TensorCore on every step. This ensures
+  // that the activations on every step observe the gradient updates from the
+  // previous step on both the sparse core and the TensorCore.
+  //
+  // true: The execution of the sparse core is pipelined with that of the
+  // TensorCore. The forward pass of every step on the sparse core can be
+  // executed after the forward pass of the previous step is complete without
+  // waiting for the backward pass. This improves the utilization of the sparse
+  // core allowing it to process step N+1 while the embedding gradients for step
+  // N are computed on the TensorCore. The backward pass of every step on the
+  // sparse core is executed directly after the forward pass for the next step
+  // is complete. The drawback is that embedding activations for step N+1 do not
+  // observe the embedding gradient updates from step N. This could affect model
+  // quality if step N and N+1 involve the same set of embedding IDs. However,
+  // since the embedding updates are sparse, this is generally not considered a
+  // problem.
+  bool pipeline_execution_with_tensor_core = 7;
+
+  // Extended output layout information; if not provided, a compatibility mode
+  // will use defaults that match the old layout. Providing a value for this
+  // field is EXPERIMENTAL and most ways of filling it will probably break. Do
+  // not set it unless you know what you are doing.
+  TPUEmbeddingOutputLayout output_layout = 8;
+}
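
The `ShardingStrategy` comment above can be sanity-checked with a small standalone sketch (the `shard_ids` helper is made up for illustration; not part of the patch) that reproduces the 13-ids-over-5-hosts example for both strategies:

    # Sketch only: reproduces the mod/div sharding examples from the proto comment.
    def shard_ids(num_ids, num_hosts, strategy):
      shards = [[] for _ in range(num_hosts)]
      if strategy == "mod":
        for i in range(num_ids):
          shards[i % num_hosts].append(i)
      else:  # "div": contiguous ranges; the first (num_ids % num_hosts) hosts get one extra id.
        base, extra = divmod(num_ids, num_hosts)
        start = 0
        for host in range(num_hosts):
          size = base + (1 if host < extra else 0)
          shards[host] = list(range(start, start + size))
          start += size
      return shards

    print(shard_ids(13, 5, "mod"))  # [[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]
    print(shard_ids(13, 5, "div"))  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]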
diff --git a/tensorflow/contrib/tpu/proto/tpu_embedding_output_layout.proto b/tensorflow/contrib/tpu/proto/tpu_embedding_output_layout.proto
new file mode 100644
index 0000000..aed30b2
--- /dev/null
+++ b/tensorflow/contrib/tpu/proto/tpu_embedding_output_layout.proto
@@ -0,0 +1,75 @@
+syntax = "proto3";
+
+package tensorflow.tpu;
+
+// In the comments here, "layout" refers to the top-level EmbeddingOutputLayout
+// proto contained in the TPUEmbeddingConfiguration.
+
+// The embedding output consists of a list of tensors, each specified by an
+// EmbeddingOutputTensor proto within the EmbeddingOutputLayout (the "output"
+// field). Each table and feature lookup is then placed into some number of
+// particular positions within some output tensor (identified by "tensor_index"
+// within OutputLocation). The tree of table lookups, feature lookups, and
+// output locations is specified by the
+// "table(table_id).feature(feature_id).output_location" repeated fields within
+// EmbeddingOutputLayout.
+
+message TPUEmbeddingOutputLayout {
+  // Location of one copy of the feature's data.
+  message OutputLocation {
+    // Which output tensor this copy of the feature will go into. Must be
+    // between 0 and layout.output_size().
+    int32 tensor_index = 1;
+
+    // Offset in dimension 0 for this feature copy. Must be between 0 and
+    // layout.output(tensor_index).dim0_size_per_sample().
+    int32 dim0_offset = 2;
+
+    // Offset in dimension 1 for this feature copy. Must be between 0 and
+    // layout.output(tensor_index).dim1_size() - table width; repeated or
+    // partially/fully overlapping values are allowed, and results in the same
+    // range will be summed (with the gradients replicated in the backward
+    // pass).
+    int32 dim1_offset = 3;
+  }
+
+  // Description of the output placement for one feature.
+  message FeatureDescriptor {
+    // Typically, only one copy of each feature is used, but multiple are
+    // allowed and the same data will be copied to all of them (with the
+    // gradients summed in the backward pass).
+    repeated OutputLocation output_location = 1;
+  }
+
+  // Description of the output placement for features of one table.
+  message TableDescriptor {
+    // Output locations for each feature loaded from this table.
+    repeated FeatureDescriptor feature = 1;
+  }
+  // Output locations for each feature of each table.
+  repeated TableDescriptor table = 1;
+
+  // Data layout and shape computation information for a single output tensor.
+  // Any unused locations in the tensor will be filled with zeros, and
+  // corresponding gradients will be ignored.
+
+  // Size and layout information for 2-D tensors.
+  message TwoDOutputTensor {
+    // Multiplier for output dimension 0 size; used to match legacy format that
+    // stacks features within a sample in dimension 0.
+    int32 dim0_size_per_sample = 2;
+
+    // The size (in dimension 1) of this output tensor.
+    int32 dim1_size = 1;
+  }
+
+  // Format information for a single output tensor.
+  message EmbeddingOutputTensor {
+    oneof output_format {
+      TwoDOutputTensor two_d = 4;
+    }
+  }
+
+  // Shape and layout information for each tensor.
+  repeated EmbeddingOutputTensor output = 2;
+}
diff --git a/tensorflow/contrib/tpu/python/tpu/device_assignment.py b/tensorflow/contrib/tpu/python/tpu/device_assignment.py
index 471b1fa..b9e2a42 100644
--- a/tensorflow/contrib/tpu/python/tpu/device_assignment.py
+++ b/tensorflow/contrib/tpu/python/tpu/device_assignment.py
@@ -72,13 +72,12 @@
         self._invert_topology(topology))
 
     topology_rank = self._topology_tasks.ndim
-    if core_assignment.ndim != topology_rank + 2:
-      raise ValueError("core_assignment must be a rank {} numpy array".format(
-          topology_rank + 2))
+    if core_assignment.ndim != 3:
+      raise ValueError("core_assignment must be a rank 3 numpy array, "
+                       "got shape {}".format(core_assignment.shape))
 
     self._num_replicas = core_assignment.shape[0]
-    self._computation_shape = np.array(
-        core_assignment.shape[1:-1], dtype=np.int32)
+    self._num_cores_per_replica = core_assignment.shape[1]
 
     if core_assignment.shape[-1] != topology_rank:
       raise ValueError(
@@ -107,18 +106,15 @@
     """Computes a nested dict which maps task and logical core to replicas."""
     task_and_cores_to_replicas = {}
     for replica in xrange(core_assignment.shape[0]):
-      for dx in xrange(core_assignment.shape[1]):
-        for dy in xrange(core_assignment.shape[2]):
-          for dz in xrange(core_assignment.shape[3]):
-            x, y, z = core_assignment[replica, dx, dy, dz, :]
-            task_id = topology_tasks[x, y, z]
-            if task_id not in task_and_cores_to_replicas:
-              task_and_cores_to_replicas[task_id] = {}
-            logical_core = (dx, dy, dz)
-            if logical_core not in task_and_cores_to_replicas[task_id]:
-              task_and_cores_to_replicas[task_id][logical_core] = set()
+      for logical_core in xrange(core_assignment.shape[1]):
+        x, y, z = core_assignment[replica, logical_core, :]
+        task_id = topology_tasks[x, y, z]
+        if task_id not in task_and_cores_to_replicas:
+          task_and_cores_to_replicas[task_id] = {}
+        if logical_core not in task_and_cores_to_replicas[task_id]:
+          task_and_cores_to_replicas[task_id][logical_core] = set()
 
-            task_and_cores_to_replicas[task_id][logical_core].add(replica)
+        task_and_cores_to_replicas[task_id][logical_core].add(replica)
 
     task_to_sorted_replica_id = {}
 
@@ -136,23 +132,9 @@
     return self._topology
 
   @property
-  def computation_shape(self):
-    """The computation shape.
-
-    Returns:
-      A rank-1 int32 numpy array with size equal to the TPU topology rank.
-      Describes the logical shape in numbers of core of each replica of the
-      computation in the TPU topology.
-
-    Returns:
-      The computation shape.
-    """
-    return self._computation_shape
-
-  @property
   def num_cores_per_replica(self):
     """The number of cores per replica."""
-    return np.prod(self.computation_shape)
+    return self._num_cores_per_replica
 
   @property
   def num_replicas(self):
@@ -164,33 +146,22 @@
     """The logical to physical core mapping.
 
     Returns:
-      A numpy array of rank `topology_rank + 2`, with shape
-      `[num_replicas] + computation_shape + [topology_rank]`. Maps
-      (replica, logical core coordinates) pairs to physical topology
-      coordinates.
+      An integer numpy array of rank 3, with shape
+      `[num_replicas, num_cores_per_replica, topology_rank]`. Maps
+      (replica, logical core) pairs to physical topology coordinates.
     """
     return self._core_assignment
 
   def _coordinates(self, replica, logical_core):
     """Returns the physical topology coordinates of a logical core."""
-    if logical_core is None:
-      logical_core = np.array([0, 0, 0], np.int32)
-    else:
-      logical_core = np.asarray(logical_core)
-
-    if any(logical_core < 0) or any(logical_core >= self.computation_shape):
-      raise ValueError("Invalid core {}; computation shape is {}".format(
-          logical_core, self.computation_shape))
-
-    logical_offset = tuple([replica] + logical_core.tolist() + [slice(3)])
-    return tuple(self.core_assignment[logical_offset])
+    return tuple(self.core_assignment[replica, logical_core, :])
 
   def lookup_replicas(self, task_id, logical_core):
     """Lookup replica ids by task number and logical core.
 
     Args:
       task_id: TensorFlow task number.
-      logical_core: A tuple of three integers which represents a logical core.
+      logical_core: An integer, identifying a logical core.
     Returns:
       A sorted list of the replicas that are attached to that task and
       logical_core.
@@ -205,17 +176,17 @@
           "Can not find any replica in task: {} contains logical_core: {} ".
           format(task_id, logical_core))
 
-  def tpu_ordinal(self, replica=0, logical_core=None):
+  def tpu_ordinal(self, replica=0, logical_core=0):
     """Returns the ordinal of the TPU device assigned to a logical core."""
     coordinates = self._coordinates(replica, logical_core)
     return self._topology_devices[coordinates]
 
-  def host_device(self, replica=0, logical_core=None, job=None):
+  def host_device(self, replica=0, logical_core=0, job=None):
     """Returns the CPU device attached to a logical core."""
     coordinates = self._coordinates(replica, logical_core)
     return _tpu_host_device_name(job, self._topology_tasks[coordinates])
 
-  def tpu_device(self, replica=0, logical_core=None, job=None):
+  def tpu_device(self, replica=0, logical_core=0, job=None):
     """Returns the name of the TPU device assigned to a logical core."""
     coordinates = self._coordinates(replica, logical_core)
     return _tpu_device_name(job, self._topology_tasks[coordinates],
@@ -228,6 +199,8 @@
                       num_replicas=1):
   """Computes a device_assignment of a computation across a TPU topology.
 
+  Attempts to choose a compact grid of cores for locality.
+
   Returns a `DeviceAssignment` that describes the cores in the topology assigned
   to each core of each replica.
 
@@ -240,12 +213,12 @@
       `initialize_system` using `Session.run`. Either a serialized
       `TopologyProto` or a `Topology` object may be passed. Note: you must
       evaluate the `Tensor` first; you cannot pass an unevaluated `Tensor` here.
-    computation_shape: A rank 1 int32 numpy array of size 3, describing the
-      shape of the computation's block of cores. If None, the
-      `computation_shape` is `[1, 1, 1]`.
-    computation_stride: A rank 1 int32 numpy array of size 3, describing the
-      inter-core spacing of the `computation_shape` cores in the TPU topology.
-      If None, the `computation_stride` is `[1, 1, 1]`.
+    computation_shape: A rank 1 int32 numpy array with size equal to the
+      topology rank, describing the shape of the computation's block of cores.
+      If None, the `computation_shape` is `[1] * topology_rank`.
+    computation_stride: A rank 1 int32 numpy array of size `topology_rank`,
+      describing the inter-core spacing of the `computation_shape` cores in the
+      TPU topology. If None, the `computation_stride` is `[1] * topology_rank`.
     num_replicas: The number of computation replicas to run. The replicas will
       be packed into the free spaces of the topology.
 
@@ -271,21 +244,21 @@
   topology_rank = len(topology.mesh_shape)
   mesh_shape = topology.mesh_shape
   if computation_shape is None:
-    computation_shape = np.array([1, 1, 1], dtype=np.int32)
+    computation_shape = np.array([1] * topology_rank, dtype=np.int32)
   else:
     computation_shape = np.asarray(computation_shape, dtype=np.int32)
 
   if computation_stride is None:
-    computation_stride = np.array([1, 1, 1], dtype=np.int32)
+    computation_stride = np.array([1] * topology_rank, dtype=np.int32)
   else:
     computation_stride = np.asarray(computation_stride, dtype=np.int32)
 
-  if computation_shape.shape != (3,):
-    raise ValueError("computation_shape must have shape [3]; got {}".format(
-        computation_shape.shape))
-  if computation_stride.shape != (3,):
-    raise ValueError("computation_stride must have shape [3]; got {}".format(
-        computation_stride.shape))
+  if computation_shape.shape != (topology_rank,):
+    raise ValueError("computation_shape must have shape [{}]; got {}".format(
+        topology_rank, computation_shape.shape))
+  if computation_stride.shape != (topology_rank,):
+    raise ValueError("computation_stride must have shape [{}]; got {}".format(
+        topology_rank, computation_stride.shape))
 
   if any(computation_shape < 1):
     raise ValueError(
@@ -315,28 +288,41 @@
             num_replicas, max_replicas, computation_shape, computation_stride,
             mesh_shape))
 
-  # Choose a compact layout for the cores. Choose the smaller dimension in the
-  # topology to be close to the square root of the number of replicas.
-  num_chips = int(math.ceil(num_replicas / replica_counts[2]))
-  target_size = int(math.ceil(math.sqrt(num_chips)))
+  def ceil_of_ratio(n, m):
+    return (n + m - 1) // m
 
-  # Prefer an even size, if possible. Odd numbered rows head back towards the
-  # first column, so it's best if the last row has an odd index.
-  if target_size % 2 != 0:
-    target_size -= 1
-  y_size = min(replica_counts[1], target_size)
-  if y_size * replica_counts[0] < num_chips:
-    y_size = replica_counts[1]
+  replica_shape = [0] * topology_rank
+  if num_replicas > 0:
+    remaining_replicas = num_replicas
+    remaining_dims = topology_rank
+
+    # Choose dimensions as close to an equal cube as possible, in order of
+    # increasing dimension size. By visiting dimensions in increasing size, we
+    # assign the most constrained dimension first, so we won't make infeasible
+    # choices.
+    #
+    # As a secondary sort order, visit the dimensions in reverse order. This
+    # means we try to use both cores on the same chip in preference to two cores
+    # on different chips.
+    for x, ni in sorted(((x, -i) for (i, x) in enumerate(replica_counts))):
+      i = -ni
+      target_size = int(math.ceil(remaining_replicas**(1.0 / remaining_dims)))
+      replica_shape[i] = min(target_size, x)
+      remaining_replicas = ceil_of_ratio(remaining_replicas, replica_shape[i])
+      remaining_dims -= 1
+
+    assert remaining_replicas == 1 and remaining_dims == 0
 
   # Assigns an offset to each replica such that no two replicas overlap.
-  replica_offsets = np.full([num_replicas, 3], -1, dtype=np.int32)
+  replica_offsets = np.full([num_replicas, topology_rank], -1, dtype=np.int32)
   for replica in xrange(num_replicas):
-    # Chooses a replica number in X/Y/Z axes.
-    z = replica % replica_counts[2]
-    t = replica // replica_counts[2]
-    y = t % y_size
-    x = t // y_size
-    replica_pos = np.array([x, y, z], dtype=np.int32)
+    # Chooses a replica number in each axis.
+    t = replica
+    pos = []
+    for dim in replica_shape[::-1]:
+      pos.append(t % dim)
+      t //= dim
+    replica_pos = np.array(pos[::-1], dtype=np.int32)
 
     # Determines where that replica starts in each axis.
     outer = replica_pos // computation_stride
@@ -351,6 +337,6 @@
   indices = np.concatenate(
       [i[..., np.newaxis] for i in np.meshgrid(*indices, indexing="ij")],
       axis=-1)
-  assignment = (
-      indices + replica_offsets[:, np.newaxis, np.newaxis, np.newaxis, :])
+  indices = indices.reshape((-1, topology_rank))
+  assignment = indices + replica_offsets[:, np.newaxis, :]
   return DeviceAssignment(topology, core_assignment=assignment)
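
The rewritten placement logic above first factors `num_replicas` into a near-cubical `replica_shape` (bounded per axis by `replica_counts`) and then decodes each flat replica index into one coordinate per topology axis. A standalone sketch of those two steps with toy values; the helper names mirror the inline logic and are made up for illustration (not part of the patch):

    # Sketch only: mirrors the replica_shape / replica_pos computation added above.
    import math

    def ceil_of_ratio(n, m):
      return (n + m - 1) // m

    def compute_replica_shape(num_replicas, replica_counts):
      """Factors num_replicas into per-axis sizes bounded by replica_counts."""
      replica_shape = [0] * len(replica_counts)
      remaining_replicas = num_replicas
      remaining_dims = len(replica_counts)
      # Visit axes from smallest capacity to largest; ties prefer the highest
      # axis, mirroring the secondary sort order in the hunk above.
      for x, ni in sorted((x, -i) for i, x in enumerate(replica_counts)):
        i = -ni
        target_size = int(math.ceil(remaining_replicas ** (1.0 / remaining_dims)))
        replica_shape[i] = min(target_size, x)
        remaining_replicas = ceil_of_ratio(remaining_replicas, replica_shape[i])
        remaining_dims -= 1
      return replica_shape

    def replica_position(replica, replica_shape):
      """Decodes a flat replica index into one coordinate per axis."""
      t, pos = replica, []
      for dim in replica_shape[::-1]:
        pos.append(t % dim)
        t //= dim
      return pos[::-1]

    shape = compute_replica_shape(8, [4, 2, 2])            # -> [2, 2, 2]
    print([replica_position(r, shape) for r in range(8)])  # 8 distinct coordinates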
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index d8c3872..bf44525 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -76,6 +76,7 @@
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import callbacks as cbks
+from tensorflow.python.keras import metrics as metrics_module
 from tensorflow.python.keras import models
 from tensorflow.python.keras import optimizers as keras_optimizers
 from tensorflow.python.keras.engine import base_layer
@@ -293,6 +294,16 @@
     return KerasCrossShardOptimizer(opt)
 
 
+def clone_metrics(metrics):
+  """Returns a copy of metrics. A copy is created for stateful metrics."""
+  if metrics is None:
+    return None
+  return [
+      m.__class__.from_config(m.get_config())
+      if isinstance(m, metrics_module.Metric) else m for m in metrics
+  ]
+
+
 class TPURewriteContext(object):
   """Prepare the environment for a Keras model during `tpu.rewrite`.
 
@@ -811,8 +822,8 @@
             optimizer=_replicated_optimizer(cloned_optimizer),
             loss=self.model.loss,
             loss_weights=self.model.loss_weights,
-            metrics=self.model.metrics,
-            weighted_metrics=self.model.weighted_metrics,
+            metrics=clone_metrics(self.model.metrics),
+            weighted_metrics=clone_metrics(self.model.weighted_metrics),
             target_tensors=tpu_targets,
         )
 
@@ -970,15 +981,25 @@
       # Note: this condition is possible during the prologue or epilogue of the
       # pipelined loop.
       return None, None
-    # Strip sample weight from inputs
+
+    if (self.model.uses_learning_phase and
+        not isinstance(K.learning_phase(), int)):
+      # Remove the learning_phase flag at the end. We currently hard code the
+      # learning_phase in TPUFunction.
+      assert isinstance(inputs[-1], int), (
+          'Expected the final element to be learning_phase flag. Got {}'.format(
+              inputs[-1]))
+      inputs = inputs[:-1]
+
     if (self.execution_mode == model_fn_lib.ModeKeys.TRAIN or
         self.execution_mode == model_fn_lib.ModeKeys.EVAL):
+      # Strip sample weight from inputs.
       input_tensors = self.model._feed_inputs + self.model._feed_targets
-      inputs = inputs[:len(input_tensors)]
-      return input_tensors, inputs
     else:
       input_tensors = self.model._feed_inputs
-      return input_tensors, inputs
+
+    inputs = inputs[:len(input_tensors)]
+    return input_tensors, inputs
 
   def _process_outputs(self, outfeed_outputs):
     """Processes the outputs of a model function execution.
diff --git a/tensorflow/contrib/tpu/python/tpu/topology.py b/tensorflow/contrib/tpu/python/tpu/topology.py
index 1fb26e7..ab89c6a 100644
--- a/tensorflow/contrib/tpu/python/tpu/topology.py
+++ b/tensorflow/contrib/tpu/python/tpu/topology.py
@@ -112,6 +112,11 @@
     return self._mesh_shape
 
   @property
+  def mesh_rank(self):
+    """Returns the number of dimensions in the mesh."""
+    return len(self._mesh_shape)
+
+  @property
   def device_coordinates(self):
     """Describes the mapping from TPU devices to topology coordinates.
 
@@ -125,6 +130,16 @@
     """
     return self._device_coordinates
 
+  @property
+  def num_tasks(self):
+    """Returns the number of TensorFlow tasks in the TPU slice."""
+    return self._device_coordinates.shape[0]
+
+  @property
+  def num_tpus_per_task(self):
+    """Returns the number of TPU devices per task in the TPU slice."""
+    return self._device_coordinates.shape[1]
+
   def serialized(self):
     """Returns the serialized form of the topology."""
     if self._serialized is None:
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index 0f9f7cd..712b02f 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -26,6 +26,7 @@
 from tensorflow.contrib.tpu.python.tpu import tpu_function
 
 from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.compat import compat as api_compat
 from tensorflow.python.framework import device as pydev
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
@@ -75,7 +76,7 @@
   """Initializes a distributed TPU system for use with TensorFlow.
 
   Args:
-    embedding_config: If not None, an `EmbeddingLayerConfiguration` proto
+    embedding_config: If not None, a `TPUEmbeddingConfiguration` proto
       describing the desired configuration of the hardware embedding lookup
       tables. If embedding_config is None, no hardware embeddings can be used.
     job: The job (the XXX in TensorFlow device specification /job:XXX) that
@@ -558,10 +559,17 @@
         "topology":
             device_assignment.topology.serialized(),
         "device_assignment":
-            device_assignment.core_assignment.flatten().tolist(),
-        "computation_shape":
-            device_assignment.computation_shape.tolist()
+            device_assignment.core_assignment.flatten().tolist()
     }
+    # TODO(phawkins): remove this case after the forward compatibility window
+    # expires on 2018-10-5.
+    if api_compat.forward_compatible(2018, 10, 5):
+      metadata_kwargs["num_cores_per_replica"] = (
+          device_assignment.num_cores_per_replica)
+    else:
+      metadata_kwargs["computation_shape"] = [
+          device_assignment.num_cores_per_replica
+      ]
 
   if ((not isinstance(inputs, list)) or
       any(not isinstance(inp, (list, tuple)) for inp in inputs)):
@@ -840,8 +848,12 @@
   if num_shards <= 0:
     raise ValueError("num_shards must be a positive integer.")
 
+  inputs = [] if inputs is None else inputs
+  if not isinstance(inputs, list):
+    raise TypeError("tpu.shard()'s inputs must be a list of Tensors or None.")
+
   # Converts inputs to Tensors.
-  inputs = [] if inputs is None else [ops.convert_to_tensor(x) for x in inputs]
+  inputs = [ops.convert_to_tensor(x) for x in inputs]
 
   if input_shard_axes is None:
     input_shard_axes = [0] * len(inputs)
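
The `forward_compatible(2018, 10, 5)` check above is the usual TensorFlow forward-compatibility-window idiom: the new `num_cores_per_replica` attribute is only emitted once the window has passed, and the deprecated `computation_shape` attribute is emitted until then so graphs built by this frontend still load on slightly older runtimes. A condensed sketch of that branch (the helper name is made up; not part of the patch):

    # Sketch of the compatibility branch added above.
    from tensorflow.python.compat import compat as api_compat

    def metadata_compat_kwargs(num_cores_per_replica):
      if api_compat.forward_compatible(2018, 10, 5):
        # After the window expires, emit the dedicated attribute.
        return {"num_cores_per_replica": num_cores_per_replica}
      # Before then, fall back to the deprecated list-valued attribute.
      return {"computation_shape": [num_cores_per_replica]}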
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
index 18e0abd..9f8d147 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
@@ -32,7 +32,6 @@
 _TF_CONFIG_ENV = run_config_lib._TF_CONFIG_ENV
 _SERVICE_KEY = run_config_lib._SERVICE_KEY
 _TPU_WORKER_JOB_NAME = 'tpu_worker_job_name'
-_NUM_CORES_PER_HOST = 8
 # pylint: enable=protected-access
 
 
@@ -103,7 +102,7 @@
       input mode.
 
     Raises:
-      ValueError: If `num_cores_per_replica` is not 1, 2, 4 or 8.
+      ValueError: If `num_cores_per_replica` is not 1, 2, 4, 8 or 16.
   """
 
   def __new__(cls,
@@ -139,9 +138,9 @@
 
     # Check num_cores_per_replica
     if num_cores_per_replica is not None:
-      if num_cores_per_replica not in [1, 2, 4, 8]:
+      if num_cores_per_replica not in [1, 2, 4, 8, 16]:
         raise ValueError(
-            'num_cores_per_replica must be 1, 2, 4, or 8; got {}'.format(
+            'num_cores_per_replica must be 1, 2, 4, 8, or 16; got {}'.format(
                 str(num_cores_per_replica)))
 
     # per_host_input_for_training may be True, False, or integer in [1..3].
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py
index 2326fe9..b2fe0a6 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py
@@ -86,7 +86,7 @@
 
   def test_fail_with_invalid_num_cores_per_replica(self):
     with self.assertRaisesRegexp(
-        ValueError, 'num_cores_per_replica must be 1, 2, 4, or 8;'
+        ValueError, 'num_cores_per_replica must be 1, 2, 4, 8, or 16;'
         ' got 7'):
       tpu_config_lib.TPUConfig(num_cores_per_replica=7)
 
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
index 19359cb..b1a8a16 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
@@ -35,7 +35,8 @@
     1: [1, 1, 1],
     2: [1, 1, 2],
     4: [1, 2, 2],
-    8: [2, 2, 2]
+    8: [2, 2, 2],
+    16: [4, 2, 2],
 }
 
 
@@ -298,6 +299,7 @@
 
   @property
   def num_of_replicas_per_host(self):
+    """Return the number of replicas per host."""
     if self.model_parallelism_enabled:
       return self.num_replicas // self.num_hosts
     else:
@@ -538,8 +540,8 @@
       """
       if self.model_parallelism_enabled:
         # We put both enqueue/dequeue ops at tpu.core(0) in each replica.
-        replica = self.device_assignment.lookup_replicas(
-            host_id, (0, 0, 0))[shard_index_in_host]
+        replica = self.device_assignment.lookup_replicas(host_id,
+                                                         0)[shard_index_in_host]
         return self.device_assignment.tpu_ordinal(replica=replica)
       else:
         return shard_index_in_host % self.num_of_cores_per_host
@@ -580,6 +582,17 @@
 
         raise ValueError(message)
 
+    if self._config.tpu_config.num_cores_per_replica:
+      num_cores_per_replica = self._config.tpu_config.num_cores_per_replica
+      num_cores_per_host = self._get_tpu_system_metadata().num_of_cores_per_host
+      if num_cores_per_replica > num_cores_per_host:
+        raise ValueError(
+            'The number of cores required by model parallelism, specified by '
+            'TPUConfig.num_cores_per_replica, is larger than the '
+            'num_cores_per_host. num_cores_per_replica: {}, '
+            'num_cores_per_host: {}'.format(num_cores_per_replica,
+                                            num_cores_per_host))
+
     if mode == model_fn_lib.ModeKeys.TRAIN:
       if (self._train_batch_size % num_replicas != 0 and
           not self.is_input_broadcast_with_iterators()):
@@ -599,8 +612,8 @@
             .format(self._eval_batch_size, num_replicas))
       if num_hosts > 1 and not self.is_input_broadcast_with_iterators():
         raise ValueError(
-            'TPUEstimator.evaluate should be running on single TPU worker. '
-            'got {}.'.format(num_hosts))
+            'TPUEstimator.evaluate should run on a single TPU'
+            ' instead of a Pod.')
     else:
       assert mode == model_fn_lib.ModeKeys.PREDICT
       if self._predict_batch_size is None:
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 1ff04f5..23c5451 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -1774,18 +1774,19 @@
         summary_writer=summary_writer)
 
   def _log_and_record(self, elapsed_steps, elapsed_time, global_step):
-    global_step_per_sec = elapsed_steps / elapsed_time
-    examples_per_sec = self._batch_size * global_step_per_sec
+    global_steps_per_sec = elapsed_steps / elapsed_time
+    examples_per_sec = self._batch_size * global_steps_per_sec
     if self._summary_writer is not None:
       global_step_summary = Summary(value=[
-          Summary.Value(tag='global_step/sec', simple_value=global_step_per_sec)
+          Summary.Value(tag='global_steps/sec',
+                        simple_value=global_steps_per_sec)
       ])
       example_summary = Summary(value=[
           Summary.Value(tag='examples/sec', simple_value=examples_per_sec)
       ])
       self._summary_writer.add_summary(global_step_summary, global_step)
       self._summary_writer.add_summary(example_summary, global_step)
-    logging.info('global_step/sec: %g', global_step_per_sec)
+    logging.info('global_steps/sec: %g', global_steps_per_sec)
     logging.info('examples/sec: %g', examples_per_sec)
 
 
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py
index d9c77a3..e75a094 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py
@@ -765,9 +765,8 @@
           zip(per_host_sharded_inputs[replica_index], inputs_part_dims_flat)
       ]
 
-      for core_index in xrange(self._device_assignment.num_cores_per_replica):
+      for logical_core in xrange(self._device_assignment.num_cores_per_replica):
        # Places different partitions on different logical cores.
-        logical_core = self._get_logical_core(core_index)
         replica_id = self._device_assignment.lookup_replicas(
             self._host_id, logical_core)[replica_index]
         ordinal = self._device_assignment.tpu_ordinal(
@@ -784,7 +783,7 @@
                   inputs=infeed_inputs,
                   shapes=[x.shape for x in infeed_inputs],
                   name="enqueue/replica_{0}/input_{1}".format(
-                      replica_index, core_index),
+                      replica_index, logical_core),
                   device_ordinal=ordinal))
     return per_host_enqueue_ops
 
@@ -890,20 +889,3 @@
     return nest.map_structure_up_to(
         dequeues, self._tag_sharding_attribute_for_dequeued_tensor, dequeues,
         dims)
-
-  def _get_logical_core(self, core_index):
-    """Maps the core index to the 3D coordinate within replica.
-
-      The lowest dimension number in computation_shape is the slowest varying
-      dimension (most major).
-
-    Args:
-      core_index: An integer represents the core index within replcia.
-
-    Returns:
-      A tuple with three integers which represents the 3D coordinate.
-    """
-    computation_shape = self._device_assignment.computation_shape
-    return (core_index // (computation_shape[1] * computation_shape[2]),
-            core_index % (computation_shape[1] * computation_shape[2]) //
-            computation_shape[2], core_index % computation_shape[2])
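
For reference, the `_get_logical_core()` helper removed above mapped a flat core index into 3-D coordinates using `computation_shape`; with the new rank-3 `core_assignment`, the flat `logical_core` index is passed through directly. A sketch of the old mapping (renamed for illustration; not part of the patch):

    # Sketch of the removed 3-D decomposition, kept here only for reference.
    def old_logical_core(core_index, computation_shape):
      _, cy, cz = computation_shape
      return (core_index // (cy * cz),
              core_index % (cy * cz) // cz,
              core_index % cz)

    print(old_logical_core(5, (1, 2, 4)))  # (0, 1, 1)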
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_function.py b/tensorflow/contrib/tpu/python/tpu/tpu_function.py
index de16e3b..0c7a38d 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_function.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_function.py
@@ -63,10 +63,9 @@
   """Validate the number of input arguments to a tpu function.
 
   Args:
-    func: the Python function that will be called to generate the body
-      of a TPUFunction.
-    input_arity: the number of explicit arguments supplied by the
-      caller.
+    func: the Python function that will be called to generate the body of an XLA
+      computation graph.
+    input_arity: the number of explicit arguments supplied by the caller.
     infeed_queue: if not None, the infeed queue that will supply
       additional arguments to the function.
 
@@ -103,4 +102,3 @@
   # Since there are varargs, func can accept any number of arguments
   # greater than the minimum.
   return None
-
diff --git a/tensorflow/contrib/verbs/rdma_mgr.cc b/tensorflow/contrib/verbs/rdma_mgr.cc
index 3cb5e61..2784bf1 100644
--- a/tensorflow/contrib/verbs/rdma_mgr.cc
+++ b/tensorflow/contrib/verbs/rdma_mgr.cc
@@ -20,7 +20,6 @@
 #include <vector>
 #include "tensorflow/contrib/verbs/grpc_verbs_client.h"
 #include "tensorflow/contrib/verbs/verbs_service.pb.h"
-#include "tensorflow/core/common_runtime/bfc_allocator.h"
 #include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
 #include "tensorflow/core/common_runtime/gpu/gpu_util.h"
 #include "tensorflow/core/common_runtime/pool_allocator.h"
@@ -29,6 +28,7 @@
 #include "tensorflow/core/distributed_runtime/session_mgr.h"
 #include "tensorflow/core/framework/allocator_registry.h"
 #include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/strcat.h"
 
 namespace tensorflow {
 
@@ -256,74 +256,41 @@
   }
 }
 
-// TODO(byronyi): remove this class and its registration when the default
-// cpu_allocator() returns visitable allocator, or cpu_allocator() is no
-// longer in use.
-class BFCRdmaAllocator : public BFCAllocator {
- public:
-  BFCRdmaAllocator()
-      : BFCAllocator(new BasicCPUAllocator(port::kNUMANoAffinity), 1LL << 36,
-                     true, "cpu_rdma_bfc") {}
-};
-class BFCRdmaAllocatorFactory : public AllocatorFactory {
- public:
-  Allocator* CreateAllocator() { return new BFCRdmaAllocator; }
-
-  SubAllocator* CreateSubAllocator(int numa_node) {
-    return new BasicCPUAllocator(numa_node);
-  }
-};
-
-REGISTER_MEM_ALLOCATOR("BFCRdmaAllocator", 101, BFCRdmaAllocatorFactory);
-
 void RdmaMgr::InitAllocators() {
-  RdmaMemoryMgr::Singleton().pd_ = rdma_adapter_->pd_;
+  static std::once_flag flag;
+  std::call_once(
+      flag, [this]() { RdmaMemoryMgr::Singleton().pd_ = rdma_adapter_->pd_; });
+}
 
-  Allocator* allocators[] = {
-#if GOOGLE_CUDA
-    GPUProcessState::singleton()->GetCUDAHostAllocator(0),
-#endif  // GOOGLE_CUDA
-    ProcessState::singleton()->GetCPUAllocator(0),
-    cpu_allocator(),
+/*static*/ void RdmaMgr::RegMemVisitors() {
+  SubAllocator::Visitor alloc_visitor = [](void* ptr, int numa_node,
+                                           size_t num_bytes) {
+    RdmaMemoryMgr::Singleton().InsertMemoryRegion(
+        ptr, num_bytes, strings::StrCat("CPU:", numa_node));
+  };
+  SubAllocator::Visitor free_visitor = [](void* ptr, int numa_node,
+                                          size_t num_bytes) {
+    RdmaMemoryMgr::Singleton().EvictMemoryRegion(ptr, num_bytes);
   };
 
-  using namespace std::placeholders;
-
-  std::set<Allocator*> instrumented_;
-
-  // Host memory allocators
-  for (Allocator* allocator : allocators) {
-    VisitableAllocator::Visitor alloc_visitor =
-        std::bind(&RdmaMemoryMgr::InsertMemoryRegion,
-                  &RdmaMemoryMgr::Singleton(), _1, _2, allocator->Name());
-    VisitableAllocator::Visitor free_visitor = std::bind(
-        &RdmaMemoryMgr::EvictMemoryRegion, &RdmaMemoryMgr::Singleton(), _1, _2);
-
-    auto* visitable_allocator = dynamic_cast<VisitableAllocator*>(allocator);
-    CHECK(visitable_allocator)
-        << "is not visitable for instrumentation" << allocator->Name();
-    // Make sure we don't instrument the same allocator twice
-    if (instrumented_.find(allocator) == std::end(instrumented_)) {
-      visitable_allocator->AddAllocVisitor(alloc_visitor);
-      visitable_allocator->AddFreeVisitor(free_visitor);
-      instrumented_.insert(allocator);
-      LOG(INFO) << "Instrumenting CPU allocator " << allocator->Name();
-    }
-  }
+  ProcessState::singleton()->AddCPUAllocVisitor(alloc_visitor);
+  ProcessState::singleton()->AddCPUFreeVisitor(free_visitor);
 
 #if GOOGLE_CUDA
   if (IsGDRAvailable()) {
     // Note we don't free allocated GPU memory so there is no free visitor
     int32_t bus_id = TryToReadNumaNode(rdma_adapter_->context_->device) + 1;
 
-    char buf[8];
-    sprintf(buf, "gpu");
-    VisitableAllocator::Visitor cuda_alloc_visitor =
-        std::bind(&RdmaMemoryMgr::InsertMemoryRegion,
-                  &RdmaMemoryMgr::Singleton(), _1, _2, std::string(buf));
-
+    SubAllocator::Visitor cuda_alloc_visitor = [](void* ptr, int gpu_id,
+                                                  size_t num_bytes) {
+      RdmaMemoryMgr::Singleton().InsertMemoryRegion(
+          ptr, num_bytes, strings::StrCat("GPU:", gpu_id));
+    };
     GPUProcessState::singleton()->AddGPUAllocVisitor(bus_id,
                                                      cuda_alloc_visitor);
+    GPUProcessState::singleton()->AddCUDAHostAllocVisitor(bus_id,
+                                                          alloc_visitor);
+    GPUProcessState::singleton()->AddCUDAHostFreeVisitor(bus_id, free_visitor);
     LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id;
   }
 #endif  // GOOGLE_CUDA
diff --git a/tensorflow/contrib/verbs/rdma_mgr.h b/tensorflow/contrib/verbs/rdma_mgr.h
index 9fffc33..74b92cc 100644
--- a/tensorflow/contrib/verbs/rdma_mgr.h
+++ b/tensorflow/contrib/verbs/rdma_mgr.h
@@ -39,6 +39,7 @@
   void SetupChannels();
   bool ConnectivityCheck();
   void InitAllocators();
+  static void RegMemVisitors();
   const string& local_worker() { return local_worker_; }
 
  private:
diff --git a/tensorflow/contrib/verbs/verbs_server_lib.cc b/tensorflow/contrib/verbs/verbs_server_lib.cc
index 1a0b502..6146968 100644
--- a/tensorflow/contrib/verbs/verbs_server_lib.cc
+++ b/tensorflow/contrib/verbs/verbs_server_lib.cc
@@ -76,8 +76,13 @@
   return Status::OK();
 }
 
+namespace {
+std::once_flag reg_mem_visitors_call;
+}  // namespace
+
 Status VerbsServer::Init(ServiceInitFunction service_func,
                          RendezvousMgrCreationFunction rendezvous_mgr_func) {
+  std::call_once(reg_mem_visitors_call, []() { RdmaMgr::RegMemVisitors(); });
   Status s = GrpcServer::Init(service_func, rendezvous_mgr_func);
   {
     mutex_lock l(mu_);
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 8f32bc2..85b6d4f 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -85,11 +85,12 @@
     "tf_cc_tests",
     "tf_copts",
     "tf_cuda_library",
+    "tf_features_nomodules_if_android",
     "tf_gen_op_libs",
     "tf_generate_proto_text_sources",
     "tf_genrule_cmd_append_to_srcs",
     "tf_opts_nortti_if_android",
-    "tf_features_nomodules_if_android",
+    "transitive_hdrs",
 )
 load("//tensorflow:tensorflow.bzl", "tf_cc_test_mkl")
 load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu")
@@ -120,16 +121,16 @@
     "tf_additional_libdevice_srcs",
     "tf_additional_minimal_lib_srcs",
     "tf_additional_mpi_lib_defines",
-    "tf_additional_proto_hdrs",
     "tf_additional_proto_compiler_hdrs",
+    "tf_additional_proto_hdrs",
     "tf_additional_proto_srcs",
     "tf_additional_test_deps",
     "tf_additional_test_srcs",
     "tf_additional_verbs_lib_defines",
     "tf_jspb_proto_library",
     "tf_kernel_tests_linkstatic",
-    "tf_lib_proto_parsing_deps",
     "tf_lib_proto_compiler_deps",
+    "tf_lib_proto_parsing_deps",
     "tf_nano_proto_library",
     "tf_platform_hdrs",
     "tf_platform_srcs",
@@ -178,7 +179,6 @@
     "framework/iterator.proto",
     "framework/kernel_def.proto",
     "framework/log_memory.proto",
-    "framework/model.proto",
     "framework/node_def.proto",
     "framework/op_def.proto",
     "framework/reader_base.proto",
@@ -842,7 +842,6 @@
         "framework/log_memory.h",
         "framework/lookup_interface.h",
         "framework/memory_types.h",
-        "framework/model.h",
         "framework/node_def_builder.h",
         "framework/node_def_util.h",
         "framework/numeric_op.h",
@@ -1068,7 +1067,6 @@
         "spectral_ops",
         "state_ops",
         "stateless_random_ops",
-        "string_ops",
         "summary_ops",
         "training_ops",
     ],
@@ -1076,6 +1074,13 @@
 
 tf_gen_op_libs(
     op_lib_names = [
+        "string_ops",
+    ],
+    deps = ["@com_google_absl//absl/strings"],
+)
+
+tf_gen_op_libs(
+    op_lib_names = [
         "array_ops",
     ],
     deps = [":protos_all_cc"],
@@ -1331,6 +1336,7 @@
         "//tensorflow/core/kernels:rpc_op",
         "//tensorflow/core/kernels:scoped_allocator_ops",
         "//tensorflow/core/kernels:sdca_ops",
+        "//tensorflow/core/kernels:searchsorted_op",
         "//tensorflow/core/kernels:set_kernels",
         "//tensorflow/core/kernels:sparse",
         "//tensorflow/core/kernels:state",
@@ -1429,9 +1435,11 @@
         ":test",
         ":testlib_ops",
         "//tensorflow/cc:scope",
+        "//tensorflow/core/kernels:cast_op",
         "//tensorflow/core/kernels:constant_op",
         "//tensorflow/core/kernels:ops_testutil",
         "//tensorflow/core/kernels:ops_util",
+        "//tensorflow/core/kernels:random_ops",
     ],
 )
 
@@ -1921,6 +1929,13 @@
 )
 
 tf_pyclif_proto_library(
+    name = "protobuf/config_pyclif",
+    proto_lib = ":protos_all_cc",
+    proto_srcfile = "protobuf/config.proto",
+    visibility = ["//visibility:public"],
+)
+
+tf_pyclif_proto_library(
     name = "protobuf/device_properties_pyclif",
     proto_lib = ":protos_all_cc",
     proto_srcfile = "protobuf/device_properties.proto",
@@ -2087,6 +2102,7 @@
     deps = tf_additional_lib_deps() + [
         "@com_google_absl//absl/strings",
         "//third_party/eigen3",
+        "@com_google_absl//absl/base:core_headers",
         "//tensorflow/core/platform/default/build_config:platformlib",
     ] + if_static([":lib_internal_impl"]),
 )
@@ -2279,6 +2295,7 @@
     deps = [
         "//tensorflow/core/platform/default/build_config:jpeg",
         "//tensorflow/core/platform/default/build_config:logging",
+        "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -2311,6 +2328,7 @@
     deps = [
         "//tensorflow/core/platform/default/build_config:gif",
         "//tensorflow/core/platform/default/build_config:logging",
+        "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -2775,7 +2793,6 @@
     "common_runtime/step_stats_collector.h",
     "common_runtime/threadpool_device.h",
     "common_runtime/tracing_device.h",
-    "common_runtime/visitable_allocator.h",
     "common_runtime/process_state.h",
     "common_runtime/pool_allocator.h",
     "graph/gradients.h",
@@ -2971,12 +2988,16 @@
     ] + tf_additional_device_tracer_deps(),
 )
 
-cc_library(
-    name = "session_ref",
-    srcs = ["common_runtime/session_ref.cc"],
-    hdrs = ["common_runtime/session_ref.h"],
-    copts = tf_copts(),
-    deps = [":core_cpu_base"],
+tf_proto_library_cc(
+    name = "replay_log_proto",
+    srcs = ["protobuf/replay_log.proto"],
+    cc_api_version = 2,
+    protodeps = [
+        ":master_proto",
+    ] + tf_additional_all_protos(),
+    visibility = [
+        "//tensorflow:internal",
+    ],
 )
 
 cc_library(
@@ -4713,6 +4734,18 @@
     ] + tf_additional_libdevice_deps(),
 )
 
+transitive_hdrs(
+    name = "headers",
+    visibility = ["//tensorflow:__subpackages__"],
+    deps = [
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:stream_executor",
+    ],
+)
+
 # -----------------------------------------------------------------------------
 # Google-internal targets go here (must be at the end).
 
diff --git a/tensorflow/core/api_def/base_api/api_def_LowerBound.pbtxt b/tensorflow/core/api_def/base_api/api_def_LowerBound.pbtxt
new file mode 100644
index 0000000..5ce825a
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_LowerBound.pbtxt
@@ -0,0 +1,45 @@
+op {
+  graph_op_name: "LowerBound"
+  visibility: HIDDEN
+  in_arg {
+    name: "sorted_inputs"
+    description: <<END
+2-D Tensor where each row is ordered.
+END
+  }
+  in_arg {
+    name: "values"
+    description: <<END
+2-D Tensor with the same number of rows as `sorted_inputs`. Contains
+the values that will be searched for in `sorted_inputs`.
+END
+  }
+  out_arg {
+    name: "output"
+    description: <<END
+A `Tensor` with the same shape as `values`.  It contains the first scalar index
+into the last dimension where values can be inserted without changing the
+ordered property.
+END
+  }
+  summary: "Applies lower_bound(sorted_search_values, values) along each row."
+  description: <<END
+Each set of rows with the same index in (sorted_inputs, values) is treated
+independently.  The resulting row is the equivalent of calling
+`np.searchsorted(sorted_inputs, values, side='left')`.
+
+The result is not a global index to the entire `Tensor`, but rather just the
+index in the last dimension.
+
+A 2-D example:
+  sorted_sequence = [[0, 3, 9, 9, 10],
+                     [1, 2, 3, 4, 5]]
+  values = [[2, 4, 9],
+            [0, 2, 6]]
+
+  result = LowerBound(sorted_sequence, values)
+
+  result == [[1, 2, 2],
+             [0, 1, 5]]
+END
+}
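For reference, the per-row behaviour documented above can be reproduced with a
minimal NumPy sketch (illustrative only, not part of the op's implementation):

  import numpy as np

  sorted_sequence = np.array([[0, 3, 9, 9, 10],
                              [1, 2, 3, 4, 5]])
  values = np.array([[2, 4, 9],
                     [0, 2, 6]])

  # Each row pair is handled independently with side='left'.
  result = np.stack([np.searchsorted(s, v, side='left')
                     for s, v in zip(sorted_sequence, values)])
  # result == [[1, 2, 2],
  #            [0, 1, 5]]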
diff --git a/tensorflow/core/api_def/base_api/api_def_PrintV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_PrintV2.pbtxt
new file mode 100644
index 0000000..4cb8955
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_PrintV2.pbtxt
@@ -0,0 +1,19 @@
+op {
+  graph_op_name: "PrintV2"
+  in_arg {
+    name: "input"
+    description: <<END
+The string scalar to print.
+END
+  }
+  attr {
+    name: "output_stream"
+    description: <<END
+A string specifying the output stream or logging level to print to.
+END
+  }
+  summary: "Prints a string scalar."
+  description: <<END
+Prints a string scalar to the desired output_stream.
+END
+}
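As a rough usage sketch, assuming the op is eventually exposed through a Python
wrapper along the lines of tf.print (the wrapper name is an assumption, not
something this patch introduces):

  import sys
  import tensorflow as tf

  # Writes the string scalar to the chosen output stream.
  tf.print("loss = 0.25", output_stream=sys.stderr)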
diff --git a/tensorflow/core/api_def/base_api/api_def_StringFormat.pbtxt b/tensorflow/core/api_def/base_api/api_def_StringFormat.pbtxt
new file mode 100644
index 0000000..a82dae9
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_StringFormat.pbtxt
@@ -0,0 +1,38 @@
+op {
+  graph_op_name: "StringFormat"
+  in_arg {
+    name: "inputs"
+    description: <<END
+The list of tensors whose summaries are formatted into the template string.
+END
+  }
+
+  out_arg {
+    name: "output"
+    description: <<END
+The resulting string scalar.
+END
+  }
+  attr {
+    name: "template"
+    description: <<END
+A string, the template to format tensor summaries into.
+END
+  }
+  attr {
+    name: "placeholder"
+    description: <<END
+A string; at each occurrence of this placeholder in the template, a subsequent tensor summary will be inserted.
+END
+  }
+  attr {
+    name: "summarize"
+    description: <<END
+When formatting the tensor summaries, print only the first and last `summarize` entries of each tensor dimension.
+END
+  }
+  summary: "Formats a string template using a list of tensors."
+  description: <<END
+Formats a string template using a list of tensors, pretty-printing tensor summaries.
+END
+}
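A small sketch of the intended formatting behaviour, using plain Python string
handling; the helper below is illustrative only and is not the op's
implementation:

  def string_format(template, tensor_summaries, placeholder="{}"):
      # Each occurrence of `placeholder` in `template` is replaced, in order,
      # by the summary of the next tensor.
      out = template
      for summary in tensor_summaries:
          out = out.replace(placeholder, summary, 1)
      return out

  string_format("x: {} y: {}", ["[1 2 3 ... 8 9 10]", "[0.5]"])
  # -> 'x: [1 2 3 ... 8 9 10] y: [0.5]'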
diff --git a/tensorflow/core/api_def/base_api/api_def_UpperBound.pbtxt b/tensorflow/core/api_def/base_api/api_def_UpperBound.pbtxt
new file mode 100644
index 0000000..0630f6e
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_UpperBound.pbtxt
@@ -0,0 +1,45 @@
+op {
+  graph_op_name: "UpperBound"
+  visibility: HIDDEN
+  in_arg {
+    name: "sorted_inputs"
+    description: <<END
+2-D Tensor where each row is ordered.
+END
+  }
+  in_arg {
+    name: "values"
+    description: <<END
+2-D Tensor with the same number of rows as `sorted_inputs`. Contains
+the values that will be searched for in `sorted_inputs`.
+END
+  }
+  out_arg {
+    name: "output"
+    description: <<END
+A `Tensor` with the same shape as `values`.  It contains the last scalar index
+into the last dimension where values can be inserted without changing the
+ordered property.
+END
+  }
+  summary: "Applies upper_bound(sorted_search_values, values) along each row."
+  description: <<END
+Each set of rows with the same index in (sorted_inputs, values) is treated
+independently.  The resulting row is the equivalent of calling
+`np.searchsorted(sorted_inputs, values, side='right')`.
+
+The result is not a global index to the entire `Tensor`, but rather just the
+index in the last dimension.
+
+A 2-D example:
+  sorted_sequence = [[0, 3, 9, 9, 10],
+                     [1, 2, 3, 4, 5]]
+  values = [[2, 4, 9],
+            [0, 2, 6]]
+
+  result = UpperBound(sorted_sequence, values)
+
+  result == [[1, 2, 4],
+             [0, 2, 5]]
+END
+}
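Reusing sorted_sequence and values from the LowerBound sketch above, the same
NumPy model applies with side='right':

  result = np.stack([np.searchsorted(s, v, side='right')
                     for s, v in zip(sorted_sequence, values)])
  # result == [[1, 2, 4],
  #            [0, 2, 5]]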
diff --git a/tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt
index 1bc3660..01387b7 100644
--- a/tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt
@@ -2,10 +2,31 @@
   visibility: HIDDEN
   graph_op_name: "WindowDataset"
   in_arg {
-    name: "window_size"
+    name: "size"
     description: <<END
 A scalar representing the number of elements to accumulate in a window.
 END
   }
+  in_arg {
+    name: "shift"
+    description: <<END
+A scalar representing the number of input elements by which the window moves
+forward in each iteration. It must be positive.
+END
+  }
+  in_arg {
+    name: "stride"
+    description: <<END
+A scalar representing the stride between input elements within the sliding
+window. It must be positive.
+END
+  }
+  in_arg {
+    name: "drop_remainder"
+    description: <<END
+A scalar representing whether a window should be dropped if its size is
+smaller than `size`.
+END
+  }
   summary: "A dataset that creates window datasets from the input dataset."
 }
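A rough Python sketch of how `size`, `shift`, `stride`, and `drop_remainder`
partition an input sequence (illustrative only; the actual kernel operates on
datasets rather than in-memory lists):

  def windows(elements, size, shift, stride, drop_remainder):
      result = []
      start = 0
      while start < len(elements):
          # Window i contains the elements at start, start + stride, ...,
          # start + (size - 1) * stride.
          window = elements[start:start + size * stride:stride]
          if len(window) == size or (window and not drop_remainder):
              result.append(window)
          start += shift
      return result

  windows(list(range(7)), size=3, shift=2, stride=1, drop_remainder=True)
  # -> [[0, 1, 2], [2, 3, 4], [4, 5, 6]]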
diff --git a/tensorflow/core/api_def/python_api/api_def_PrintV2.pbtxt b/tensorflow/core/api_def/python_api/api_def_PrintV2.pbtxt
new file mode 100644
index 0000000..e22d980
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_PrintV2.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "PrintV2"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StringFormat.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringFormat.pbtxt
new file mode 100644
index 0000000..8f0b1db
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StringFormat.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "StringFormat"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/common_runtime/bfc_allocator.cc b/tensorflow/core/common_runtime/bfc_allocator.cc
index 84c6285..3843ea9 100644
--- a/tensorflow/core/common_runtime/bfc_allocator.cc
+++ b/tensorflow/core/common_runtime/bfc_allocator.cc
@@ -31,7 +31,7 @@
 
 BFCAllocator::BFCAllocator(SubAllocator* sub_allocator, size_t total_memory,
                            bool allow_growth, const string& name)
-    : suballocator_(sub_allocator),
+    : sub_allocator_(sub_allocator),
       name_(name),
       free_chunks_list_(kInvalidChunkHandle),
       next_allocation_id_(1) {
@@ -72,7 +72,7 @@
   VLOG(2) << "Number of regions allocated: "
           << region_manager_.regions().size();
   for (const auto& region : region_manager_.regions()) {
-    suballocator_->Free(region.ptr(), region.memory_size());
+    sub_allocator_->Free(region.ptr(), region.memory_size());
   }
 
   for (BinNum b = 0; b < kNumBins; b++) {
@@ -108,7 +108,7 @@
 
   // Try allocating.
   size_t bytes = std::min(curr_region_allocation_bytes_, available_bytes);
-  void* mem_addr = suballocator_->Alloc(alignment, bytes);
+  void* mem_addr = sub_allocator_->Alloc(alignment, bytes);
   if (mem_addr == nullptr && !started_backpedal_) {
     // Only backpedal once.
     started_backpedal_ = true;
@@ -119,7 +119,7 @@
     while (mem_addr == nullptr) {
       bytes = RoundedBytes(bytes * kBackpedalFactor);
       if (bytes < rounded_bytes) break;
-      mem_addr = suballocator_->Alloc(alignment, bytes);
+      mem_addr = sub_allocator_->Alloc(alignment, bytes);
     }
   }
 
@@ -158,10 +158,6 @@
   // Insert the chunk into the right bin.
   InsertFreeChunkIntoBin(h);
 
-  // Invoke visitors on newly allocated region.
-  for (const auto& visitor : region_visitors_) {
-    visitor(mem_addr, bytes);
-  }
   return true;
 }
 
@@ -490,15 +486,6 @@
   InsertFreeChunkIntoBin(coalesced_chunk);
 }
 
-void BFCAllocator::AddAllocVisitor(Visitor visitor) {
-  VLOG(1) << "AddVisitor";
-  mutex_lock l(lock_);
-  region_visitors_.push_back(visitor);
-  for (const auto& region : region_manager_.regions()) {
-    visitor(region.ptr(), region.memory_size());
-  }
-}
-
 bool BFCAllocator::TracksAllocationSizes() { return true; }
 
 size_t BFCAllocator::RequestedSize(const void* ptr) {
diff --git a/tensorflow/core/common_runtime/bfc_allocator.h b/tensorflow/core/common_runtime/bfc_allocator.h
index 20e1dab..2d74bf2 100644
--- a/tensorflow/core/common_runtime/bfc_allocator.h
+++ b/tensorflow/core/common_runtime/bfc_allocator.h
@@ -23,7 +23,7 @@
 #include <vector>
 
 #include "tensorflow/core/common_runtime/allocator_retry.h"
-#include "tensorflow/core/common_runtime/visitable_allocator.h"
+#include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/lib/gtl/stl_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/macros.h"
@@ -42,7 +42,7 @@
 // coalescing.  One assumption we make is that the process using this
 // allocator owns pretty much all of the memory, and that nearly
 // all requests to allocate memory go through this interface.
-class BFCAllocator : public VisitableAllocator {
+class BFCAllocator : public Allocator {
  public:
   // Takes ownership of sub_allocator.
   BFCAllocator(SubAllocator* sub_allocator, size_t total_memory,
@@ -55,11 +55,6 @@
                     const AllocationAttributes& allocation_attr) override;
   void DeallocateRaw(void* ptr) override;
 
-  void AddAllocVisitor(Visitor visitor) override;
-
-  // Does nothing, because memory is never freed.
-  void AddFreeVisitor(Visitor visitor) override {}
-
   bool TracksAllocationSizes() override;
 
   size_t RequestedSize(const void* ptr) override;
@@ -309,7 +304,7 @@
   };
 
   // Returns 'bytes' rounded up to the next highest kMinAllocationSize.
-  size_t RoundedBytes(size_t bytes);
+  static size_t RoundedBytes(size_t bytes);
 
   // Try to add a new memory region that can satisfy an allocation of
   // 'rounded_bytes' bytes.  Returns true on success and false on
@@ -423,7 +418,7 @@
   // of the available memory.
   bool started_backpedal_ = false;
 
-  std::unique_ptr<SubAllocator> suballocator_;
+  std::unique_ptr<SubAllocator> sub_allocator_;
   string name_;
 
   // Structures mutable after construction
@@ -435,9 +430,6 @@
   // Pointer to head of linked list of free Chunks
   ChunkHandle free_chunks_list_ GUARDED_BY(lock_);
 
-  // Called once on each region, ASAP.
-  std::vector<Visitor> region_visitors_ GUARDED_BY(lock_);
-
   // Counter containing the next unique identifier to assign to a
   // newly-created chunk.
   int64 next_allocation_id_ GUARDED_BY(lock_);
diff --git a/tensorflow/core/common_runtime/device.h b/tensorflow/core/common_runtime/device.h
index 81d68e3..fb76d6a 100644
--- a/tensorflow/core/common_runtime/device.h
+++ b/tensorflow/core/common_runtime/device.h
@@ -106,6 +106,10 @@
   // at completion.
   virtual Status Sync() = 0;
 
+  // Override this to return true for devices that require a Sync() call before
+  // session completion.
+  virtual bool RequiresSyncOnCompletion() const { return false; }
+
   // Optionally modify the device's GraphDef before execution.
   //
   // This method should be considered experimental and is supplied to enable
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index b4d8e28..af5d5b1 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -1202,14 +1202,11 @@
     auto opseg = device->op_segment();
     params.create_kernel = [this, lib, opseg](const NodeDef& ndef,
                                               OpKernel** kernel) {
-      // We do not share the kernel via the OpSegment if the node is
-      // stateless, or a function.
       // NOTE(mrry): We must not share function kernels (implemented
       // using `CallOp`) between subgraphs, because `CallOp::handle_`
       // is tied to a particular subgraph. Even if the function itself
       // is stateful, the `CallOp` that invokes it is not.
-      if (!lib->IsStateful(ndef.op()) ||
-          lib->GetFunctionLibraryDefinition()->Find(ndef.op()) != nullptr) {
+      if (!OpSegment::ShouldOwnKernel(lib, ndef.op())) {
         return lib->CreateKernel(ndef, kernel);
       }
       auto create_fn = [lib, &ndef](OpKernel** kernel) {
@@ -1222,10 +1219,8 @@
                                  create_fn);
     };
     params.delete_kernel = [lib](OpKernel* kernel) {
-      // If the node is stateful, opseg owns it. Otherwise, delete it.
-      if (kernel && !lib->IsStateful(kernel->type_string())) {
+      if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string()))
         delete kernel;
-      }
     };
 
     optimizer.Optimize(lib, options_.env, device, &partition_graph,
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 263467a..18420b6 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -32,6 +32,18 @@
   return default_val;
 }
 
+std::unique_ptr<thread::ThreadPool> EagerThreadPool(
+    const SessionOptions& opts) {
+  SessionOptions opts_copy(opts);
+  if (opts_copy.config.inter_op_parallelism_threads() == 0) {
+    // Eager defaults to a single thread when no threads are specified.
+    opts_copy.config.set_inter_op_parallelism_threads(1);
+  }
+
+  return std::unique_ptr<thread::ThreadPool>(
+      NewThreadPoolFromSessionOptions(opts_copy));
+}
+
 }  // namespace
 
 EagerContext::EagerContext(const SessionOptions& opts,
@@ -49,7 +61,7 @@
     : policy_(default_policy),
       devices_(device_mgr->ListDevices()),
       rendezvous_(rendezvous),
-      thread_pool_(NewThreadPoolFromSessionOptions(opts)),
+      thread_pool_(EagerThreadPool(opts)),
       pflr_(new ProcessFunctionLibraryRuntime(
           device_mgr, opts.env, TF_GRAPH_DEF_VERSION, &func_lib_def_, {},
           thread_pool_.get())),
@@ -67,7 +79,7 @@
   }
   InitDeviceMapAndAsync();
   runner_ = [this](std::function<void()> closure) {
-    this->thread_pool_->Schedule(closure);
+    this->thread_pool_->Schedule(std::move(closure));
   };
 }
 
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 1da1326..1bc6361 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -251,26 +251,6 @@
   EagerContext* ctx = op->EagerContext();
   auto status = ctx->GetStatus();
   if (!status.ok()) return status;
-  // Ensure all resource-touching ops run in the device the resource is,
-  // regardless of anything else that has been specified. This is identical to
-  // the graph mode behavior.
-  for (int i = 0; i < op->Inputs().size(); ++i) {
-    Device* input_op_device = nullptr;
-    status = op->Inputs()[i]->OpDevice(&input_op_device);
-    if (!status.ok()) return status;
-    VLOG(2) << "for op " << op->Name() << " input " << i << " "
-            << DataTypeString(op->Inputs()[i]->dtype) << " "
-            << (input_op_device == nullptr ? "cpu" : input_op_device->name())
-            << " " << (op->Device() == nullptr ? "cpu" : op->Device()->name());
-    if (op->Inputs()[i]->dtype == DT_RESOURCE &&
-        (input_op_device != op->Device() || input_op_device == nullptr)) {
-      Device* d = input_op_device == nullptr ? ctx->HostCPU() : input_op_device;
-      VLOG(1) << "Changing device of operation " << op->Name() << " to "
-              << d->name() << " because input #" << i
-              << " is a resource in this device.";
-      op->SetDevice(d);
-    }
-  }
   Device* device = op->Device();
 
   Fprint128 cache_key = op->MutableAttrs()->CacheKey(
@@ -604,6 +584,27 @@
 Status EagerExecute(EagerOperation* op,
                     gtl::InlinedVector<TensorHandle*, 2>* retvals,
                     int* num_retvals) {
+  // Ensure all resource-touching ops run in the device the resource is,
+  // regardless of anything else that has been specified. This is identical to
+  // the graph mode behavior.
+  EagerContext* ctx = op->EagerContext();
+  for (int i = 0; i < op->Inputs().size(); ++i) {
+    Device* input_op_device = nullptr;
+    auto status = op->Inputs()[i]->OpDevice(&input_op_device);
+    if (!status.ok()) return status;
+    VLOG(2) << "for op " << op->Name() << " input " << i << " "
+            << DataTypeString(op->Inputs()[i]->dtype) << " "
+            << (input_op_device == nullptr ? "cpu" : input_op_device->name())
+            << " " << (op->Device() == nullptr ? "cpu" : op->Device()->name());
+    if (op->Inputs()[i]->dtype == DT_RESOURCE &&
+        (input_op_device != op->Device() || input_op_device == nullptr)) {
+      Device* d = input_op_device == nullptr ? ctx->HostCPU() : input_op_device;
+      VLOG(1) << "Changing device of operation " << op->Name() << " to "
+              << d->name() << " because input #" << i
+              << " is a resource in this device.";
+      op->SetDevice(d);
+    }
+  }
   bool op_is_local = IsLocal(op->EagerContext(), op->Device());
 
   if (op_is_local) {
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc
index b912f7d..d58724c 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.cc
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc
@@ -125,7 +125,6 @@
 Status TensorHandle::NumDims(int* num_dims) {
   if (IsRemote()) {
     TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false));
-    CHECK(remote_shape_ != nullptr);
     *num_dims = remote_shape_->dims();
   } else {
     TF_RETURN_IF_ERROR(WaitReady());
@@ -153,6 +152,21 @@
   return Status::OK();
 }
 
+Status TensorHandle::NumElements(int64* num_elements) {
+  if (IsRemote()) {
+    TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false));
+    *num_elements = remote_shape_->num_elements();
+  } else {
+    TF_RETURN_IF_ERROR(WaitReady());
+    DCHECK(IsReady());
+    DCHECK(num_elements != nullptr);
+
+    *num_elements = tensor_.NumElements();
+  }
+
+  return Status::OK();
+}
+
 Status TensorHandle::RemoteAddress(int64* op_id, int32* output_num) {
   if (!IsRemote()) {
     return errors::FailedPrecondition(
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h
index 1bc9c65..e55f1a0 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.h
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.h
@@ -113,6 +113,7 @@
 
   Status NumDims(int* num_dims);
   Status Dim(int dim_index, int64* dim);
+  Status NumElements(int64* num_elements);
 
   // Return the op_id and output num if the handle refers to a remote tensor.
   Status RemoteAddress(int64* op_id, int32* output_num);
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 8486539..9871954 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -76,56 +76,47 @@
 namespace nodestats {
 inline int64 NowInNsec() { return Env::Default()->NowNanos(); }
 
-void SetScheduled(NodeExecStatsWrapper* stats, int64 micros) {
+void SetScheduled(NodeExecStatsInterface* stats, int64 micros) {
   if (!stats) return;
   stats->SetScheduled(micros * EnvTime::kMicrosToNanos);
 }
 
-void SetAllStart(NodeExecStatsWrapper* stats) {
+void SetAllStart(NodeExecStatsInterface* stats) {
   if (!stats) return;
   stats->RecordExecutorStarted();
 }
 
-void SetOpStart(NodeExecStatsWrapper* stats) {
+void SetOpStart(NodeExecStatsInterface* stats) {
   if (!stats) return;
   stats->RecordComputeStarted();
 }
 
-void SetOpEnd(NodeExecStatsWrapper* stats) {
+void SetOpEnd(NodeExecStatsInterface* stats) {
   if (!stats) return;
   stats->RecordComputeEnded();
 }
 
-void SetAllEnd(NodeExecStatsWrapper* stats) {
+void SetAllEnd(NodeExecStatsInterface* stats) {
   if (!stats) return;
   stats->RecordExecutorEnded();
 }
 
-void SetOutput(NodeExecStatsWrapper* stats, int slot, const Tensor* v) {
+void SetOutput(NodeExecStatsInterface* stats, int slot, const Tensor* v) {
   if (!stats) return;
   stats->SetOutput(slot, v);
 }
 
-void SetMemory(NodeExecStatsWrapper* stats, OpKernelContext* ctx) {
+void SetMemory(NodeExecStatsInterface* stats, OpKernelContext* ctx) {
   if (!stats) return;
   stats->SetMemory(ctx);
 }
 
-void SetReferencedTensors(NodeExecStatsWrapper* stats,
+void SetReferencedTensors(NodeExecStatsInterface* stats,
                           const TensorReferenceVector& tensors) {
   if (!stats) return;
   stats->SetReferencedTensors(tensors);
 }
 
-// Sets the timeline_label field of *stats, using data from *node.
-// Returns true iff the node is a transfer node.
-bool SetTimelineLabel(const Node* node, NodeExecStatsWrapper* stats) {
-  if (!stats) {
-    return false;
-  }
-  return stats->SetTimelineLabel(node);
-}
-
 }  // namespace nodestats
 
 class ExecutorImpl;
@@ -1301,7 +1292,7 @@
 
   // After item->kernel computation is done, processes its outputs.
   Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
-                        EntryVector* outputs, NodeExecStatsWrapper* stats);
+                        EntryVector* outputs, NodeExecStatsInterface* stats);
 
   // After processing the outputs, propagates the outputs to their dsts.
   // Contents of *outputs are left in an indeterminate state after
@@ -1312,7 +1303,7 @@
   // "node" just finishes. Takes ownership of "stats". Returns true if
   // execution has completed.
   bool NodeDone(const Status& s, const Node* node, const TaggedNodeSeq& ready,
-                NodeExecStatsWrapper* stats,
+                NodeExecStatsInterface* stats,
                 TaggedNodeReadyQueue* inline_ready);
 
   // Schedule all the expensive nodes in 'ready', and put all the inexpensive
@@ -1513,7 +1504,7 @@
 struct ExecutorState::AsyncState {
   AsyncState(const OpKernelContext::Params& p, const TaggedNode& _tagged_node,
              const NodeItem* _item, Entry* _first_input,
-             NodeExecStatsWrapper* _stats)
+             NodeExecStatsInterface* _stats)
       : saved_inputs(*p.inputs),
         saved_input_device_contexts(*p.input_device_contexts),
         saved_input_alloc_attrs(*p.input_alloc_attrs),
@@ -1538,7 +1529,7 @@
   const NodeItem* item;
   Entry* first_input;
   OpKernelContext ctx;
-  NodeExecStatsWrapper* stats;
+  NodeExecStatsInterface* stats;
 
  private:
   OpKernelContext::Params* ParamsButClearingEigenGPUDevice(
@@ -1583,7 +1574,7 @@
   params.stats_collector = stats_collector_;
 
   Status s;
-  NodeExecStatsWrapper* stats = nullptr;
+  NodeExecStatsInterface* stats = nullptr;
   EntryVector outputs;
   bool completed = false;
   inline_ready.push_back(tagged_node);
@@ -1613,7 +1604,7 @@
     if (stats_collector_ && !tagged_node.is_dead) {
       // track allocations if and only if we are collecting statistics
       params.track_allocations = true;
-      stats = new NodeExecStatsWrapper(node->name());
+      stats = stats_collector_->CreateNodeExecStats(node);
       nodestats::SetScheduled(stats, scheduled_nsec);
       nodestats::SetAllStart(stats);
     }
@@ -1671,7 +1662,7 @@
 
         auto done = [this, state]() {
           Device* device = impl_->params_.device;
-          NodeExecStatsWrapper* stats = state->stats;  // Shorthand
+          NodeExecStatsInterface* stats = state->stats;  // Shorthand
           Entry* first_input = state->first_input;     // Shorthand
 
           nodestats::SetOpEnd(stats);
@@ -1862,7 +1853,7 @@
 
 Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
                                      EntryVector* outputs,
-                                     NodeExecStatsWrapper* stats) {
+                                     NodeExecStatsInterface* stats) {
   const Node* node = item.node;
   DCHECK_EQ(0, outputs->size());
   outputs->resize(item.num_outputs);
@@ -2080,16 +2071,15 @@
 
 bool ExecutorState::NodeDone(const Status& s, const Node* node,
                              const TaggedNodeSeq& ready,
-                             NodeExecStatsWrapper* stats,
+                             NodeExecStatsInterface* stats,
                              TaggedNodeReadyQueue* inline_ready) {
   nodestats::SetAllEnd(stats);
-  if (stats_collector_ != nullptr &&
-      !nodestats::SetTimelineLabel(node, stats)) {
-    // Only record non-transfer nodes.
-    // Transfers 'stats' ownership to 'stats_collector_'.
-    stats_collector_->Save(impl_->params_.device->name(), stats);
-  } else if (stats) {
-    delete stats;
+  if (stats) {
+    if (stats_collector_) {
+      stats->Done(impl_->params_.device->name());
+    } else {
+      delete stats;
+    }
   }
 
   bool abort_run = false;
@@ -2311,13 +2301,15 @@
   auto done_cb = std::move(done_cb_);
   auto runner = std::move(runner_);
   mu_.unlock();
-  if (sync_on_finish_ && status.ok()) {
+  Device* device = impl_->params_.device;
+  if ((sync_on_finish_ && status.ok()) || device->RequiresSyncOnCompletion()) {
     // Block until the device has finished all queued operations. For
     // devices like GPUs that continue to execute Ops after their Compute
     // methods have completed, this ensures that control is not returned to
     // the user until the step (and its side-effects) has actually completed.
-    status = impl_->params_.device->Sync();
+    status.Update(device->Sync());
   }
+
   delete this;
   CHECK(done_cb != nullptr);
   runner([=]() { done_cb(status); });
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 1c9b697..472865c 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -414,9 +414,8 @@
       device_type, device_, device_->GetAllocator(AllocatorAttributes()), &ndef,
       &fbody->fdef.signature(), this, fbody->arg_types, input_memory_types,
       fbody->ret_types, output_memory_types, graph_def_version_, &s);
-  *kernel = new CallOp(handle, &construction);
-  if (!s.ok()) {
-    delete *kernel;
+  if (s.ok()) {
+    *kernel = new CallOp(handle, &construction);
   }
   return s;
 }
diff --git a/tensorflow/core/common_runtime/gpu/cuda_host_allocator.h b/tensorflow/core/common_runtime/gpu/cuda_host_allocator.h
index 636cd43..6bd29ef7 100644
--- a/tensorflow/core/common_runtime/gpu/cuda_host_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/cuda_host_allocator.h
@@ -26,8 +26,12 @@
 class CUDAHostAllocator : public SubAllocator {
  public:
   // Note: stream_exec cannot be null.
-  explicit CUDAHostAllocator(se::StreamExecutor* stream_exec)
-      : stream_exec_(stream_exec) {
+  explicit CUDAHostAllocator(se::StreamExecutor* stream_exec, int numa_node,
+                             const std::vector<Visitor>& alloc_visitors,
+                             const std::vector<Visitor>& free_visitors)
+      : SubAllocator(alloc_visitors, free_visitors),
+        stream_exec_(stream_exec),
+        numa_node_(numa_node) {
     CHECK(stream_exec_ != nullptr);
   }
   ~CUDAHostAllocator() override {}
@@ -39,19 +43,23 @@
       if (ptr == nullptr) {
         LOG(WARNING) << "could not allocate pinned host memory of size: "
                      << num_bytes;
+        return ptr;
       }
+      VisitAlloc(ptr, numa_node_, num_bytes);
     }
     return ptr;
   }
 
   void Free(void* ptr, size_t num_bytes) override {
     if (ptr != nullptr) {
+      VisitFree(ptr, numa_node_, num_bytes);
       stream_exec_->HostMemoryDeallocate(ptr);
     }
   }
 
  private:
   se::StreamExecutor* stream_exec_;  // not owned, non-null
+  const int numa_node_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(CUDAHostAllocator);
 };
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
index 2d4c8d0..42021e5 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
@@ -22,18 +22,48 @@
 
 namespace tensorflow {
 
-GPUBFCAllocator::GPUBFCAllocator(CudaGpuId cuda_gpu_id, size_t total_memory,
-                                 const string& name)
-    : GPUBFCAllocator(cuda_gpu_id, total_memory, GPUOptions(), name) {}
+bool GPUBFCAllocator::GetAllowGrowthValue(const GPUOptions& gpu_options) {
+  const char* force_allow_growth_string =
+      std::getenv("TF_FORCE_GPU_ALLOW_GROWTH");
+  if (force_allow_growth_string == nullptr) {
+    return gpu_options.allow_growth();
+  }
 
-GPUBFCAllocator::GPUBFCAllocator(CudaGpuId cuda_gpu_id, size_t total_memory,
+  if (strcmp("false", force_allow_growth_string) == 0) {
+    if (gpu_options.allow_growth()) {
+      LOG(WARNING)
+          << "Overriding allow_growth setting because the"
+          << " TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original"
+          << " config value was " << gpu_options.allow_growth() << ".";
+    }
+    return false;
+  } else if (strcmp("true", force_allow_growth_string) == 0) {
+    if (!gpu_options.allow_growth()) {
+      LOG(WARNING)
+          << "Overriding allow_growth setting because the"
+          << " TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original"
+          << " config value was " << gpu_options.allow_growth() << ".";
+    }
+    return true;
+  }
+
+  LOG(ERROR)
+      << "The TF_FORCE_GPU_ALLOW_GROWTH environment variable is set but could"
+      << " not be parsed: \"" << force_allow_growth_string << "\". Valid"
+      << " values are \"true\" or \"false\". Using original config value"
+      << " of " << gpu_options.allow_growth() << ".";
+  return gpu_options.allow_growth();
+}
+
+GPUBFCAllocator::GPUBFCAllocator(GPUMemAllocator* sub_allocator,
+                                 size_t total_memory, const string& name)
+    : GPUBFCAllocator(sub_allocator, total_memory, GPUOptions(), name) {}
+
+GPUBFCAllocator::GPUBFCAllocator(GPUMemAllocator* sub_allocator,
+                                 size_t total_memory,
                                  const GPUOptions& gpu_options,
                                  const string& name)
-    : BFCAllocator(
-          new GPUMemAllocator(
-              GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(),
-              gpu_options.per_process_gpu_memory_fraction() > 1.0 ||
-                  gpu_options.experimental().use_unified_memory()),
-          total_memory, gpu_options.allow_growth(), name) {}
+    : BFCAllocator(sub_allocator, total_memory,
+                   GPUBFCAllocator::GetAllowGrowthValue(gpu_options), name) {}
 
 }  // namespace tensorflow
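The precedence implemented by GetAllowGrowthValue above can be modelled with a
short Python sketch (a rough model of the decision logic, not the C++ code
itself):

  import os

  def effective_allow_growth(config_allow_growth):
      forced = os.environ.get("TF_FORCE_GPU_ALLOW_GROWTH")
      if forced is None:
          return config_allow_growth   # env var unset: keep the config value
      if forced == "true":
          return True                  # env var overrides the config
      if forced == "false":
          return False
      return config_allow_growth       # unparseable value: fall back to config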
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
index f1cc2ea..d4c9cee 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
@@ -31,28 +31,20 @@
 
 namespace tensorflow {
 
-// A GPU memory allocator that implements a 'best-fit with coalescing'
-// algorithm.
-class GPUBFCAllocator : public BFCAllocator {
- public:
-  // 'cuda_gpu_id' refers to the ID of the GPU device within
-  // the process and must reference a valid ID in the process.
-  GPUBFCAllocator(CudaGpuId cuda_gpu_id, size_t total_memory,
-                  const string& name);
-  GPUBFCAllocator(CudaGpuId cuda_gpu_id, size_t total_memory,
-                  const GPUOptions& gpu_options, const string& name);
-  virtual ~GPUBFCAllocator() {}
-
-  TF_DISALLOW_COPY_AND_ASSIGN(GPUBFCAllocator);
-};
-
 // Suballocator for GPU memory.
 class GPUMemAllocator : public SubAllocator {
  public:
+  // 'platform_gpu_id' refers to the ID of the GPU device within
+  // the process and must reference a valid ID in the process.
   // Note: stream_exec cannot be null.
   explicit GPUMemAllocator(se::StreamExecutor* stream_exec,
-                           bool use_unified_memory)
-      : stream_exec_(stream_exec), use_unified_memory_(use_unified_memory) {
+                           PlatformGpuId gpu_id, bool use_unified_memory,
+                           const std::vector<Visitor>& alloc_visitors,
+                           const std::vector<Visitor>& free_visitors)
+      : SubAllocator(alloc_visitors, free_visitors),
+        stream_exec_(stream_exec),
+        gpu_id_(gpu_id),
+        use_unified_memory_(use_unified_memory) {
     CHECK(stream_exec_ != nullptr);
   }
   ~GPUMemAllocator() override {}
@@ -65,12 +57,14 @@
       } else {
         ptr = stream_exec_->AllocateArray<char>(num_bytes).opaque();
       }
+      VisitAlloc(ptr, gpu_id_.value(), num_bytes);
     }
     return ptr;
   }
 
   void Free(void* ptr, size_t num_bytes) override {
     if (ptr != nullptr) {
+      VisitFree(ptr, gpu_id_.value(), num_bytes);
       if (use_unified_memory_) {
         stream_exec_->UnifiedMemoryDeallocate(ptr);
       } else {
@@ -82,11 +76,28 @@
 
  private:
   se::StreamExecutor* stream_exec_;  // not owned, non-null
+  const PlatformGpuId gpu_id_;
   const bool use_unified_memory_ = false;
 
   TF_DISALLOW_COPY_AND_ASSIGN(GPUMemAllocator);
 };
 
+// A GPU memory allocator that implements a 'best-fit with coalescing'
+// algorithm.
+class GPUBFCAllocator : public BFCAllocator {
+ public:
+  GPUBFCAllocator(GPUMemAllocator* sub_allocator, size_t total_memory,
+                  const string& name);
+  GPUBFCAllocator(GPUMemAllocator* sub_allocator, size_t total_memory,
+                  const GPUOptions& gpu_options, const string& name);
+  ~GPUBFCAllocator() override {}
+
+  TF_DISALLOW_COPY_AND_ASSIGN(GPUBFCAllocator);
+
+ private:
+  static bool GetAllowGrowthValue(const GPUOptions& gpu_options);
+};
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_BFC_ALLOCATOR_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc
index 67caeb3..60e82ed 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc
@@ -21,6 +21,7 @@
 #include <vector>
 
 #include "tensorflow/core/common_runtime/gpu/gpu_id.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_id_utils.h"
 #include "tensorflow/core/common_runtime/gpu/gpu_init.h"
 #include "tensorflow/core/lib/core/threadpool.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -46,7 +47,11 @@
 }
 
 TEST(GPUBFCAllocatorTest, NoDups) {
-  GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc");
+  PlatformGpuId platform_gpu_id(0);
+  GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+      GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+      platform_gpu_id, false /*use_unified_memory*/, {}, {});
+  GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
   CheckStats(&a, 0, 0, 0, 0);
 
   // Allocate a lot of raw pointers
@@ -75,7 +80,11 @@
 }
 
 TEST(GPUBFCAllocatorTest, AllocationsAndDeallocations) {
-  GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc");
+  PlatformGpuId platform_gpu_id(0);
+  GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+      GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+      platform_gpu_id, false /*use_unified_memory*/, {}, {});
+  GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
   // Allocate 256 raw pointers of sizes between 100 bytes and about
   // a meg
   random::PhiloxRandom philox(123, 17);
@@ -133,7 +142,11 @@
 }
 
 TEST(GPUBFCAllocatorTest, ExerciseCoalescing) {
-  GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc");
+  PlatformGpuId platform_gpu_id(0);
+  GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+      GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+      platform_gpu_id, false /*use_unified_memory*/, {}, {});
+  GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
   CheckStats(&a, 0, 0, 0, 0);
 
   float* first_ptr = a.Allocate<float>(1024);
@@ -168,18 +181,30 @@
 }
 
 TEST(GPUBFCAllocatorTest, AllocateZeroBufSize) {
-  GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc");
+  PlatformGpuId platform_gpu_id(0);
+  GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+      GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+      platform_gpu_id, false /*use_unified_memory*/, {}, {});
+  GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
   float* ptr = a.Allocate<float>(0);
   EXPECT_EQ(nullptr, ptr);
 }
 
 TEST(GPUBFCAllocatorTest, TracksSizes) {
-  GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc");
+  PlatformGpuId platform_gpu_id(0);
+  GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+      GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+      platform_gpu_id, false /*use_unified_memory*/, {}, {});
+  GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
   EXPECT_EQ(true, a.TracksAllocationSizes());
 }
 
 TEST(GPUBFCAllocatorTest, AllocatedVsRequested) {
-  GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc");
+  PlatformGpuId platform_gpu_id(0);
+  GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+      GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+      platform_gpu_id, false /*use_unified_memory*/, {}, {});
+  GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
   float* t1 = a.Allocate<float>(1);
   EXPECT_EQ(4, a.RequestedSize(t1));
   EXPECT_EQ(256, a.AllocatedSize(t1));
@@ -187,8 +212,12 @@
 }
 
 TEST(GPUBFCAllocatorTest, TestCustomMemoryLimit) {
+  PlatformGpuId platform_gpu_id(0);
+  GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+      GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+      platform_gpu_id, false /*use_unified_memory*/, {}, {});
   // Configure a 1MiB byte limit
-  GPUBFCAllocator a(CudaGpuId(0), 1 << 20, "GPU_0_bfc");
+  GPUBFCAllocator a(sub_allocator, 1 << 20, "GPU_0_bfc");
 
   float* first_ptr = a.Allocate<float>(1 << 6);
   float* second_ptr = a.Allocate<float>(1 << 20);
@@ -203,7 +232,11 @@
   options.set_allow_growth(true);
 
   // Max of 2GiB, but starts out small.
-  GPUBFCAllocator a(CudaGpuId(0), 1LL << 31, options, "GPU_0_bfc");
+  PlatformGpuId platform_gpu_id(0);
+  GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+      GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+      platform_gpu_id, false /*use_unified_memory*/, {}, {});
+  GPUBFCAllocator a(sub_allocator, 1LL << 31, "GPU_0_bfc");
 
   // Allocate 10 raw pointers of sizes between 100 bytes and about
   // 64 megs.
@@ -264,8 +297,15 @@
 }
 
 TEST(GPUBFCAllocatorTest, DISABLED_AllocatorReceivesZeroMemory) {
-  GPUBFCAllocator a(CudaGpuId(0), 1UL << 60, "GPU_0_bfc");
-  GPUBFCAllocator b(CudaGpuId(0), 1UL << 60, "GPU_0_bfc");
+  PlatformGpuId platform_gpu_id(0);
+  GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+      GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+      platform_gpu_id, false /*use_unified_memory*/, {}, {});
+  GPUBFCAllocator a(sub_allocator, 1UL << 60, "GPU_0_bfc");
+  sub_allocator = new GPUMemAllocator(
+      GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+      platform_gpu_id, false /*use_unified_memory*/, {}, {});
+  GPUBFCAllocator b(sub_allocator, 1UL << 60, "GPU_0_bfc");
   void* amem = a.AllocateRaw(1, 1);
   void* bmem = b.AllocateRaw(1, 1 << 30);
   a.DeallocateRaw(amem);
@@ -273,7 +313,11 @@
 }
 
 static void BM_Allocation(int iters) {
-  GPUBFCAllocator a(CudaGpuId(0), 1uLL << 33, "GPU_0_bfc");
+  PlatformGpuId platform_gpu_id(0);
+  GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+      GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+      platform_gpu_id, false /*use_unified_memory*/, {}, {});
+  GPUBFCAllocator a(sub_allocator, 1uLL << 33, "GPU_0_bfc");
   // Exercise a few different allocation sizes
   std::vector<size_t> sizes = {256,        4096,      16384,    524288,
                                512,        1048576,   10485760, 104857600,
@@ -289,7 +333,11 @@
 BENCHMARK(BM_Allocation);
 
 static void BM_AllocationThreaded(int iters, int num_threads) {
-  GPUBFCAllocator a(CudaGpuId(0), 1uLL << 33, "GPU_0_bfc");
+  PlatformGpuId platform_gpu_id(0);
+  GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+      GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+      platform_gpu_id, false /*use_unified_memory*/, {}, {});
+  GPUBFCAllocator a(sub_allocator, 1uLL << 33, "GPU_0_bfc");
   thread::ThreadPool pool(Env::Default(), "test", num_threads);
   std::atomic_int_fast32_t count(iters);
   mutex done_lock;
@@ -325,7 +373,11 @@
 // A more complex benchmark that defers deallocation of an object for
 // "delay" allocations.
 static void BM_AllocationDelayed(int iters, int delay) {
-  GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc");
+  PlatformGpuId platform_gpu_id(0);
+  GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+      GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+      platform_gpu_id, false /*use_unified_memory*/, {}, {});
+  GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
   // Exercise a few different allocation sizes
   std::vector<int> sizes = {256, 4096, 16384, 4096, 512, 1024, 1024};
   int size_index = 0;
@@ -358,12 +410,18 @@
 
 class GPUBFCAllocatorPrivateMethodsTest : public ::testing::Test {
  protected:
+  void SetUp() override { CHECK_EQ(unsetenv("TF_FORCE_GPU_ALLOW_GROWTH"), 0); }
+
   // The following test methods are called from tests. The reason for this is
   // that this class is a friend class to BFCAllocator, but tests are not, so
   // only methods inside this class can access private members of BFCAllocator.
 
   void TestBinDebugInfo() {
-    GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc");
+    PlatformGpuId platform_gpu_id(0);
+    GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+        GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+        platform_gpu_id, false /*use_unified_memory*/, {}, {});
+    GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
 
     std::vector<void*> initial_ptrs;
     std::vector<size_t> initial_ptrs_allocated_sizes;
@@ -441,7 +499,11 @@
   }
 
   void TestLog2FloorNonZeroSlow() {
-    GPUBFCAllocator a(CudaGpuId(0), 1 /* total_memory */, "GPU_0_bfc");
+    PlatformGpuId platform_gpu_id(0);
+    GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+        GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+        platform_gpu_id, false /*use_unified_memory*/, {}, {});
+    GPUBFCAllocator a(sub_allocator, 1 /* total_memory */, "GPU_0_bfc");
     EXPECT_EQ(-1, a.Log2FloorNonZeroSlow(0));
     EXPECT_EQ(0, a.Log2FloorNonZeroSlow(1));
     EXPECT_EQ(1, a.Log2FloorNonZeroSlow(2));
@@ -450,6 +512,56 @@
     EXPECT_EQ(10, a.Log2FloorNonZeroSlow(1024));
     EXPECT_EQ(10, a.Log2FloorNonZeroSlow(1025));
   }
+
+  void TestForceAllowGrowth() {
+    PlatformGpuId platform_gpu_id(0);
+    GPUOptions options;
+    // Unset flag value uses provided option.
+    unsetenv("TF_FORCE_GPU_ALLOW_GROWTH");
+    options.set_allow_growth(true);
+    GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+        GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+        platform_gpu_id, false /*use_unified_memory*/, {}, {});
+    GPUBFCAllocator unset_flag_allocator(sub_allocator, 1LL << 31, options,
+                                         "GPU_0_bfc");
+    EXPECT_EQ(GPUBFCAllocator::RoundedBytes(size_t{1048576}),
+              unset_flag_allocator.curr_region_allocation_bytes_);
+
+    // Unparseable flag value uses provided option.
+    setenv("TF_FORCE_GPU_ALLOW_GROWTH", "unparseable", 1);
+    options.set_allow_growth(true);
+    sub_allocator = new GPUMemAllocator(
+        GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+        platform_gpu_id, false /*use_unified_memory*/, {}, {});
+    GPUBFCAllocator unparsable_flag_allocator(sub_allocator, 1LL << 31, options,
+                                              "GPU_1_bfc");
+    EXPECT_EQ(GPUBFCAllocator::RoundedBytes(size_t{1048576}),
+              unparsable_flag_allocator.curr_region_allocation_bytes_);
+
+    // Max of 2GiB total memory. Env variable set forces allow_growth, which
+    // does an initial allocation of 1MiB.
+    setenv("TF_FORCE_GPU_ALLOW_GROWTH", "true", 1);
+    options.set_allow_growth(false);
+    sub_allocator = new GPUMemAllocator(
+        GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+        platform_gpu_id, false /*use_unified_memory*/, {}, {});
+    GPUBFCAllocator force_allow_growth_allocator(sub_allocator, 1LL << 31,
+                                                 options, "GPU_2_bfc");
+    EXPECT_EQ(GPUBFCAllocator::RoundedBytes(size_t{1048576}),
+              force_allow_growth_allocator.curr_region_allocation_bytes_);
+
+    // If env variable forces allow_growth disabled, all available memory is
+    // allocated.
+    setenv("TF_FORCE_GPU_ALLOW_GROWTH", "false", 1);
+    options.set_allow_growth(true);
+    sub_allocator = new GPUMemAllocator(
+        GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+        platform_gpu_id, false /*use_unified_memory*/, {}, {});
+    GPUBFCAllocator force_no_allow_growth_allocator(sub_allocator, 1LL << 31,
+                                                    options, "GPU_3_bfc");
+    EXPECT_EQ(GPUBFCAllocator::RoundedBytes(1LL << 31),
+              force_no_allow_growth_allocator.curr_region_allocation_bytes_);
+  }
 };
 
 TEST_F(GPUBFCAllocatorPrivateMethodsTest, BinDebugInfo) { TestBinDebugInfo(); }
@@ -458,6 +570,10 @@
   TestLog2FloorNonZeroSlow();
 }
 
+TEST_F(GPUBFCAllocatorPrivateMethodsTest, ForceAllowGrowth) {
+  TestForceAllowGrowth();
+}
+
 }  // namespace tensorflow
 
 #endif  // GOOGLE_CUDA
diff --git a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc
index 934a57a..d85ca88 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc
@@ -27,10 +27,11 @@
 
 namespace tensorflow {
 
-GPUcudaMallocAllocator::GPUcudaMallocAllocator(VisitableAllocator* allocator,
-                                               CudaGpuId cuda_gpu_id)
+GPUcudaMallocAllocator::GPUcudaMallocAllocator(Allocator* allocator,
+                                               PlatformGpuId platform_gpu_id)
     : base_allocator_(allocator) {
-  stream_exec_ = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+  stream_exec_ =
+      GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
 }
 
 GPUcudaMallocAllocator::~GPUcudaMallocAllocator() { delete base_allocator_; }
@@ -60,14 +61,6 @@
 #endif  // GOOGLE_CUDA
 }
 
-void GPUcudaMallocAllocator::AddAllocVisitor(Visitor visitor) {
-  return base_allocator_->AddAllocVisitor(visitor);
-}
-
-void GPUcudaMallocAllocator::AddFreeVisitor(Visitor visitor) {
-  return base_allocator_->AddFreeVisitor(visitor);
-}
-
 bool GPUcudaMallocAllocator::TracksAllocationSizes() { return false; }
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h
index 856fdc3..8df3724 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h
@@ -19,7 +19,7 @@
 #include <memory>
 
 #include "tensorflow/core/common_runtime/gpu/gpu_id.h"
-#include "tensorflow/core/common_runtime/visitable_allocator.h"
+#include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/stream_executor.h"
 #include "tensorflow/core/platform/types.h"
@@ -29,20 +29,18 @@
 // An allocator that wraps a GPU allocator and adds debugging
 // functionality that verifies that users do not write outside their
 // allocated memory.
-class GPUcudaMallocAllocator : public VisitableAllocator {
+class GPUcudaMallocAllocator : public Allocator {
  public:
-  explicit GPUcudaMallocAllocator(VisitableAllocator* allocator,
-                                  CudaGpuId cuda_gpu_id);
+  explicit GPUcudaMallocAllocator(Allocator* allocator,
+                                  PlatformGpuId platform_gpu_id);
   ~GPUcudaMallocAllocator() override;
   string Name() override { return "gpu_debug"; }
   void* AllocateRaw(size_t alignment, size_t num_bytes) override;
   void DeallocateRaw(void* ptr) override;
-  void AddAllocVisitor(Visitor visitor) override;
-  void AddFreeVisitor(Visitor visitor) override;
   bool TracksAllocationSizes() override;
 
  private:
-  VisitableAllocator* base_allocator_ = nullptr;  // owned
+  Allocator* base_allocator_ = nullptr;  // owned
 
   se::StreamExecutor* stream_exec_;  // Not owned.
 
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc
index e4c834b..989ddbe 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc
@@ -73,10 +73,11 @@
 // -----------------------------------------------------------------------------
 // GPUDebugAllocator
 // -----------------------------------------------------------------------------
-GPUDebugAllocator::GPUDebugAllocator(VisitableAllocator* allocator,
-                                     CudaGpuId cuda_gpu_id)
+GPUDebugAllocator::GPUDebugAllocator(Allocator* allocator,
+                                     PlatformGpuId platform_gpu_id)
     : base_allocator_(allocator) {
-  stream_exec_ = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+  stream_exec_ =
+      GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
 }
 
 GPUDebugAllocator::~GPUDebugAllocator() { delete base_allocator_; }
@@ -111,14 +112,6 @@
   base_allocator_->DeallocateRaw(ptr);
 }
 
-void GPUDebugAllocator::AddAllocVisitor(Visitor visitor) {
-  return base_allocator_->AddAllocVisitor(visitor);
-}
-
-void GPUDebugAllocator::AddFreeVisitor(Visitor visitor) {
-  return base_allocator_->AddFreeVisitor(visitor);
-}
-
 bool GPUDebugAllocator::TracksAllocationSizes() { return true; }
 
 size_t GPUDebugAllocator::RequestedSize(const void* ptr) {
@@ -158,10 +151,11 @@
 // -----------------------------------------------------------------------------
 // GPUNanResetAllocator
 // -----------------------------------------------------------------------------
-GPUNanResetAllocator::GPUNanResetAllocator(VisitableAllocator* allocator,
-                                           CudaGpuId cuda_gpu_id)
+GPUNanResetAllocator::GPUNanResetAllocator(Allocator* allocator,
+                                           PlatformGpuId platform_gpu_id)
     : base_allocator_(allocator) {
-  stream_exec_ = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+  stream_exec_ =
+      GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
 }
 
 GPUNanResetAllocator::~GPUNanResetAllocator() { delete base_allocator_; }
@@ -200,14 +194,6 @@
   base_allocator_->DeallocateRaw(ptr);
 }
 
-void GPUNanResetAllocator::AddAllocVisitor(Visitor visitor) {
-  return base_allocator_->AddAllocVisitor(visitor);
-}
-
-void GPUNanResetAllocator::AddFreeVisitor(Visitor visitor) {
-  return base_allocator_->AddFreeVisitor(visitor);
-}
-
 size_t GPUNanResetAllocator::RequestedSize(const void* ptr) {
   return base_allocator_->RequestedSize(ptr);
 }
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
index 0f9b720..17757a1 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
@@ -21,7 +21,7 @@
 #include <unordered_map>
 
 #include "tensorflow/core/common_runtime/gpu/gpu_id.h"
-#include "tensorflow/core/common_runtime/visitable_allocator.h"
+#include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/stream_executor.h"
 #include "tensorflow/core/platform/types.h"
@@ -31,16 +31,14 @@
 // An allocator that wraps a GPU allocator and adds debugging
 // functionality that verifies that users do not write outside their
 // allocated memory.
-class GPUDebugAllocator : public VisitableAllocator {
+class GPUDebugAllocator : public Allocator {
  public:
-  explicit GPUDebugAllocator(VisitableAllocator* allocator,
-                             CudaGpuId cuda_gpu_id);
+  explicit GPUDebugAllocator(Allocator* allocator,
+                             PlatformGpuId platform_gpu_id);
   ~GPUDebugAllocator() override;
   string Name() override { return "gpu_debug"; }
   void* AllocateRaw(size_t alignment, size_t num_bytes) override;
   void DeallocateRaw(void* ptr) override;
-  void AddAllocVisitor(Visitor visitor) override;
-  void AddFreeVisitor(Visitor visitor) override;
   bool TracksAllocationSizes() override;
   size_t RequestedSize(const void* ptr) override;
   size_t AllocatedSize(const void* ptr) override;
@@ -53,7 +51,7 @@
   bool CheckFooter(void* ptr);
 
  private:
-  VisitableAllocator* base_allocator_ = nullptr;  // owned
+  Allocator* base_allocator_ = nullptr;  // owned
 
   se::StreamExecutor* stream_exec_;  // Not owned.
 
@@ -63,23 +61,21 @@
 // An allocator that wraps a GPU allocator and resets the memory on
 // allocation and free to 'NaN', helping to identify cases where the
 // user forgets to initialize the memory.
-class GPUNanResetAllocator : public VisitableAllocator {
+class GPUNanResetAllocator : public Allocator {
  public:
-  explicit GPUNanResetAllocator(VisitableAllocator* allocator,
-                                CudaGpuId cuda_gpu_id);
+  explicit GPUNanResetAllocator(Allocator* allocator,
+                                PlatformGpuId platform_gpu_id);
   ~GPUNanResetAllocator() override;
   string Name() override { return "gpu_nan_reset"; }
   void* AllocateRaw(size_t alignment, size_t num_bytes) override;
   void DeallocateRaw(void* ptr) override;
-  void AddAllocVisitor(Visitor visitor) override;
-  void AddFreeVisitor(Visitor visitor) override;
   size_t RequestedSize(const void* ptr) override;
   size_t AllocatedSize(const void* ptr) override;
   void GetStats(AllocatorStats* stats) override;
   void ClearStats() override;
 
  private:
-  VisitableAllocator* base_allocator_ = nullptr;  // owned
+  Allocator* base_allocator_ = nullptr;  // owned
 
   se::StreamExecutor* stream_exec_;  // Not owned.
 
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc
index 236a0af..aca08a7 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc
@@ -34,10 +34,14 @@
 namespace {
 
 TEST(GPUDebugAllocatorTest, OverwriteDetection_None) {
-  const CudaGpuId cuda_gpu_id(0);
-  GPUDebugAllocator a(new GPUBFCAllocator(cuda_gpu_id, 1 << 30, ""),
-                      cuda_gpu_id);
-  auto stream_exec = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+  const PlatformGpuId platform_gpu_id(0);
+  GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+      GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+      platform_gpu_id, false /*use_unified_memory*/, {}, {});
+  GPUDebugAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
+                      platform_gpu_id);
+  auto stream_exec =
+      GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
 
   for (int s : {8}) {
     std::vector<int64> cpu_array(s);
@@ -58,11 +62,14 @@
   for (int s : {8, 211}) {
     EXPECT_DEATH(
         {
-          const CudaGpuId cuda_gpu_id(0);
-          GPUDebugAllocator a(new GPUBFCAllocator(cuda_gpu_id, 1 << 30, ""),
-                              cuda_gpu_id);
+          const PlatformGpuId platform_gpu_id(0);
+          GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+              GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+              platform_gpu_id, false /*use_unified_memory*/, {}, {});
+          GPUDebugAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
+                              platform_gpu_id);
           auto stream_exec =
-              GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+              GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
 
           std::vector<int64> cpu_array(s);
           memset(&cpu_array[0], 0, cpu_array.size() * sizeof(int64));
@@ -91,11 +98,14 @@
   for (int s : {8, 22}) {
     EXPECT_DEATH(
         {
-          const CudaGpuId cuda_gpu_id(0);
-          GPUDebugAllocator a(new GPUBFCAllocator(cuda_gpu_id, 1 << 30, ""),
-                              cuda_gpu_id);
+          const PlatformGpuId platform_gpu_id(0);
+          GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+              GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+              platform_gpu_id, false /*use_unified_memory*/, {}, {});
+          GPUDebugAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
+                              platform_gpu_id);
           auto stream_exec =
-              GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+              GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
 
           std::vector<int64> cpu_array(s);
           memset(&cpu_array[0], 0, cpu_array.size() * sizeof(int64));
@@ -121,10 +131,14 @@
 }
 
 TEST(GPUDebugAllocatorTest, ResetToNan) {
-  const CudaGpuId cuda_gpu_id(0);
-  GPUNanResetAllocator a(new GPUBFCAllocator(cuda_gpu_id, 1 << 30, ""),
-                         cuda_gpu_id);
-  auto stream_exec = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+  const PlatformGpuId platform_gpu_id(0);
+  GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+      GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+      platform_gpu_id, false /*use_unified_memory*/, {}, {});
+  GPUNanResetAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
+                         platform_gpu_id);
+  auto stream_exec =
+      GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
 
   std::vector<float> cpu_array(1024);
   std::vector<float> cpu_array_result(1024);
@@ -161,13 +175,17 @@
 }
 
 TEST(GPUDebugAllocatorTest, ResetToNanWithHeaderFooter) {
-  const CudaGpuId cuda_gpu_id(0);
+  const PlatformGpuId platform_gpu_id(0);
   // NaN reset must be the outer-most allocator.
+  GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+      GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+      platform_gpu_id, false /*use_unified_memory*/, {}, {});
   GPUNanResetAllocator a(
-      new GPUDebugAllocator(new GPUBFCAllocator(cuda_gpu_id, 1 << 30, ""),
-                            cuda_gpu_id),
-      cuda_gpu_id);
-  auto stream_exec = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+      new GPUDebugAllocator(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
+                            platform_gpu_id),
+      platform_gpu_id);
+  auto stream_exec =
+      GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
 
   std::vector<float> cpu_array(1024);
   std::vector<float> cpu_array_result(1024);
@@ -204,18 +222,24 @@
 }
 
 TEST(GPUDebugAllocatorTest, TracksSizes) {
-  const CudaGpuId cuda_gpu_id(0);
-  GPUDebugAllocator a(new GPUBFCAllocator(cuda_gpu_id, 1 << 30, ""),
-                      cuda_gpu_id);
+  const PlatformGpuId platform_gpu_id(0);
+  GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+      GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+      platform_gpu_id, false /*use_unified_memory*/, {}, {});
+  GPUDebugAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
+                      platform_gpu_id);
   EXPECT_EQ(true, a.TracksAllocationSizes());
 }
 
 TEST(GPUDebugAllocatorTest, AllocatedVsRequested) {
-  const CudaGpuId cuda_gpu_id(0);
+  const PlatformGpuId platform_gpu_id(0);
+  GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+      GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+      platform_gpu_id, false /*use_unified_memory*/, {}, {});
   GPUNanResetAllocator a(
-      new GPUDebugAllocator(new GPUBFCAllocator(cuda_gpu_id, 1 << 30, ""),
-                            cuda_gpu_id),
-      cuda_gpu_id);
+      new GPUDebugAllocator(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
+                            platform_gpu_id),
+      platform_gpu_id);
   float* t1 = a.Allocate<float>(1);
   EXPECT_EQ(4, a.RequestedSize(t1));
   EXPECT_EQ(256, a.AllocatedSize(t1));
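A minimal C++17 sketch of the guard-word technique the test above exercises: each allocation is bracketed by a known pattern and the pattern is re-checked on free. GuardAllocator and kGuardWord are hypothetical names used only for illustration; the real GPUDebugAllocator applies the same idea to device memory through the StreamExecutor rather than malloc/free.

#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <unordered_map>

// Illustrative host-side wrapper: writes a guard word before and after each
// user region and verifies both words when the region is freed.
class GuardAllocator {
 public:
  void* Allocate(std::size_t num_bytes) {
    auto* base = static_cast<unsigned char*>(
        std::malloc(num_bytes + 2 * sizeof(kGuardWord)));
    if (base == nullptr) return nullptr;
    std::memcpy(base, &kGuardWord, sizeof(kGuardWord));  // header guard
    std::memcpy(base + sizeof(kGuardWord) + num_bytes, &kGuardWord,
                sizeof(kGuardWord));                      // footer guard
    void* user_ptr = base + sizeof(kGuardWord);
    sizes_[user_ptr] = num_bytes;
    return user_ptr;
  }

  void Deallocate(void* ptr) {
    CheckGuards(ptr);  // aborts if the client wrote outside its region
    sizes_.erase(ptr);
    std::free(static_cast<unsigned char*>(ptr) - sizeof(kGuardWord));
  }

  void CheckGuards(void* ptr) const {
    auto it = sizes_.find(ptr);
    assert(it != sizes_.end());
    const auto* base =
        static_cast<const unsigned char*>(ptr) - sizeof(kGuardWord);
    assert(std::memcmp(base, &kGuardWord, sizeof(kGuardWord)) == 0);
    assert(std::memcmp(base + sizeof(kGuardWord) + it->second, &kGuardWord,
                       sizeof(kGuardWord)) == 0);
  }

 private:
  static constexpr std::uint64_t kGuardWord = 0xDEADBEEFCAFEF00Dull;
  std::unordered_map<void*, std::size_t> sizes_;
};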
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index 2763ac0..cf3faf6 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -41,7 +41,6 @@
 #include "tensorflow/core/common_runtime/gpu/gpu_util.h"
 #include "tensorflow/core/common_runtime/gpu_device_context.h"
 #include "tensorflow/core/common_runtime/local_device.h"
-#include "tensorflow/core/common_runtime/visitable_allocator.h"
 #include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/framework/device_base.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -105,9 +104,9 @@
         reinterpret_cast<unsigned int*>(scratch + Eigen::kCudaScratchSize);
     stream_ = cuda_stream;
     allocator_ = alloc;
-    CudaGpuId cuda_gpu_id;
-    TF_CHECK_OK(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id));
-    device_prop_ = &Eigen::m_deviceProperties[cuda_gpu_id.value()];
+    PlatformGpuId platform_gpu_id;
+    TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id));
+    device_prop_ = &Eigen::m_deviceProperties[platform_gpu_id.value()];
   }
 
   const cudaStream_t& stream() const override { return *stream_; }
@@ -285,6 +284,38 @@
   for (auto ctx : device_contexts_) ctx->Unref();
 }
 
+// This is idempotent: if the scratch buffers are already initialized, it is a no-op.
+Status BaseGPUDevice::InitScratchBuffers() {
+  mutex_lock l(scratch_init_mutex_);
+  if (scratch_.size() < max_streams_) {
+    for (int i = 0; i < max_streams_; i++) {
+      DCHECK(streams_[i]);
+      if (scratch_.size() > i && scratch_[i]) continue;
+      size_t scratch_buffer_size =
+          Eigen::kCudaScratchSize + sizeof(unsigned int);
+      void* scratch_buffer = gpu_allocator_->AllocateRaw(
+          Allocator::kAllocatorAlignment, scratch_buffer_size);
+      if (scratch_buffer == nullptr) {
+        return errors::FailedPrecondition(
+            "Failed to allocate scratch buffer for device ",
+            tf_gpu_id_.value());
+      }
+      se::DeviceMemory<char> mem(
+          se::DeviceMemoryBase(scratch_buffer, scratch_buffer_size));
+
+      bool ok = executor_->SynchronousMemZero(
+          &mem, Eigen::kCudaScratchSize + sizeof(unsigned int));
+      if (!ok) {
+        return errors::FailedPrecondition(
+            "Failed to memcopy into scratch buffer for device ",
+            tf_gpu_id_.value());
+      }
+      scratch_.push_back(static_cast<char*>(scratch_buffer));
+    }
+  }
+  return Status::OK();
+}
+
 Status BaseGPUDevice::Init(const SessionOptions& options) {
   auto executor_status = GpuIdUtil::ExecutorForTfGpuId(tf_gpu_id_);
   if (!executor_status.status().ok()) {
@@ -303,27 +334,6 @@
   for (int i = 0; i < max_streams_; i++) {
     streams_.push_back(StreamGroupFactory::Global().GetOrCreate(
         tf_gpu_id_, i, executor_, options.config.gpu_options()));
-
-    size_t scratch_buffer_size = Eigen::kCudaScratchSize + sizeof(unsigned int);
-    void* scratch_buffer = gpu_allocator_->AllocateRaw(
-        Allocator::kAllocatorAlignment, scratch_buffer_size);
-    if (scratch_buffer == nullptr) {
-      return errors::FailedPrecondition(
-          "Failed to allocate scratch buffer for device ", tf_gpu_id_.value());
-    }
-    scratch_.push_back(static_cast<char*>(scratch_buffer));
-
-    se::DeviceMemory<char> mem(
-        se::DeviceMemoryBase(scratch_buffer, scratch_buffer_size));
-
-    bool ok = executor_->SynchronousMemZero(
-        &mem, Eigen::kCudaScratchSize + sizeof(unsigned int));
-    if (!ok) {
-      return errors::FailedPrecondition(
-          "Failed to memcopy into scratch buffer for device ",
-          tf_gpu_id_.value());
-    }
-
     device_contexts_.push_back(new GPUDeviceContext(
         i, streams_.back()->compute, streams_.back()->host_to_device,
         streams_.back()->device_to_host, streams_.back()->device_to_device));
@@ -332,9 +342,10 @@
   gpu_device_info_->stream = streams_[0]->compute;
   gpu_device_info_->default_context = device_contexts_[0];
   gpu_device_info_->event_mgr = em_.get();
-  CudaGpuId cuda_gpu_id;
-  TF_RETURN_IF_ERROR(GpuIdManager::TfToCudaGpuId(tf_gpu_id_, &cuda_gpu_id));
-  gpu_device_info_->gpu_id = cuda_gpu_id.value();
+  PlatformGpuId platform_gpu_id;
+  TF_RETURN_IF_ERROR(
+      GpuIdManager::TfToPlatformGpuId(tf_gpu_id_, &platform_gpu_id));
+  gpu_device_info_->gpu_id = platform_gpu_id.value();
   set_tensorflow_gpu_device_info(gpu_device_info_);
 
   // Whether and how the GPU device uses its own threadpool.
@@ -690,9 +701,9 @@
   Eigen::GpuDevice device_;
 };
 
-// Parse 'visible_device_list' into a list of CUDA GPU ids.
+// Parse 'visible_device_list' into a list of platform GPU ids.
 Status ParseVisibleDeviceList(const string& visible_device_list,
-                              std::vector<CudaGpuId>* visible_gpu_order) {
+                              std::vector<PlatformGpuId>* visible_gpu_order) {
   visible_gpu_order->clear();
   se::Platform* gpu_manager = GPUMachineManager();
 
@@ -707,26 +718,28 @@
   } else {
     const std::vector<string> order_str =
         str_util::Split(visible_device_list, ',');
-    for (const string& cuda_gpu_id_str : order_str) {
-      int32 cuda_gpu_id;
-      if (!strings::safe_strto32(cuda_gpu_id_str, &cuda_gpu_id)) {
+    for (const string& platform_gpu_id_str : order_str) {
+      int32 platform_gpu_id;
+      if (!strings::safe_strto32(platform_gpu_id_str, &platform_gpu_id)) {
         return errors::InvalidArgument(
             "Could not parse entry in 'visible_device_list': '",
-            cuda_gpu_id_str, "'. visible_device_list = ", visible_device_list);
+            platform_gpu_id_str, "'. visible_device_list = ",
+            visible_device_list);
       }
-      if (cuda_gpu_id < 0 || cuda_gpu_id >= gpu_manager->VisibleDeviceCount()) {
+      if (platform_gpu_id < 0 ||
+          platform_gpu_id >= gpu_manager->VisibleDeviceCount()) {
         return errors::InvalidArgument(
-            "'visible_device_list' listed an invalid GPU id '", cuda_gpu_id,
+            "'visible_device_list' listed an invalid GPU id '", platform_gpu_id,
             "' but visible device count is ",
             gpu_manager->VisibleDeviceCount());
       }
-      visible_gpu_order->push_back(CudaGpuId(cuda_gpu_id));
+      visible_gpu_order->push_back(PlatformGpuId(platform_gpu_id));
     }
   }
 
   // Validate no repeats.
-  std::set<CudaGpuId> visible_device_set(visible_gpu_order->begin(),
-                                         visible_gpu_order->end());
+  std::set<PlatformGpuId> visible_device_set(visible_gpu_order->begin(),
+                                             visible_gpu_order->end());
   if (visible_device_set.size() != visible_gpu_order->size()) {
     return errors::InvalidArgument(
         "visible_device_list contained a duplicate entry: ",
@@ -737,8 +750,8 @@
 
 Status VerifyVirtualDeviceSettings(
     const size_t num_gpus_to_use, const GPUOptions& gpu_options,
-    const std::vector<CudaGpuId>& visible_gpu_order,
-    const std::vector<CudaGpuId>& valid_cuda_gpu_ids) {
+    const std::vector<PlatformGpuId>& visible_gpu_order,
+    const std::vector<PlatformGpuId>& valid_platform_gpu_ids) {
   const auto& virtual_devices = gpu_options.experimental().virtual_devices();
   CHECK(!virtual_devices.empty());
   if (gpu_options.per_process_gpu_memory_fraction() > 0) {
@@ -760,11 +773,11 @@
         " #GPUs in visible_device_list: ", visible_gpu_order.size(),
         " virtual_devices.size(): ", virtual_devices.size());
   }
-  if (valid_cuda_gpu_ids.size() != virtual_devices.size()) {
+  if (valid_platform_gpu_ids.size() != virtual_devices.size()) {
     return errors::Unknown(
         "The number of valid GPUs doesn't match the number of elements in "
         "the virtual_devices list.",
-        " #valid GPUs: ", valid_cuda_gpu_ids.size(),
+        " #valid GPUs: ", valid_platform_gpu_ids.size(),
         " virtual_devices.size(): ", virtual_devices.size());
   }
   return Status::OK();
@@ -806,18 +819,18 @@
 }
 
 // Get the memory limit for the virtual device being created on GPU with
-// 'cuda_gpu_id', when that virtual device is the only virtual device being
+// 'platform_gpu_id', when that virtual device is the only virtual device being
 // created on that GPU.
 Status SingleVirtualDeviceMemoryLimit(const GPUOptions& gpu_options,
-                                      CudaGpuId cuda_gpu_id,
+                                      PlatformGpuId platform_gpu_id,
                                       int64* memory_limit) {
   int64 total_memory = 0;
   int64 available_memory = 0;
   se::StreamExecutor* se =
-      GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+      GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
   if (!se->DeviceMemoryUsage(&available_memory, &total_memory)) {
     return errors::Unknown("Failed to query available memory for GPU ",
-                           cuda_gpu_id.value());
+                           platform_gpu_id.value());
   }
 
   int64 allocated_memory = 0;
@@ -867,10 +880,11 @@
   return new ConcretePerOpGpuDevice();
 }
 
-void BaseGPUDevice::ReinitializeGpuDevice(OpKernelContext* context,
-                                          PerOpGpuDevice* device,
-                                          DeviceContext* dc,
-                                          Allocator* allocator) {
+Status BaseGPUDevice::ReinitializeGpuDevice(OpKernelContext* context,
+                                            PerOpGpuDevice* device,
+                                            DeviceContext* dc,
+                                            Allocator* allocator) {
+  TF_RETURN_IF_ERROR(InitScratchBuffers());
   if (dc) {
     const GPUDeviceContext* gpu_dc = static_cast<GPUDeviceContext*>(dc);
     const int stream_id = gpu_dc->stream_id();
@@ -881,6 +895,7 @@
   } else {
     ReinitializeDevice(context, device, 0, allocator);
   }
+  return Status::OK();
 }
 
 Allocator* BaseGPUDevice::GetScopedAllocator(AllocatorAttributes attr,
@@ -916,8 +931,8 @@
     num_gpus_to_use = iter->second;
   }
   const auto& gpu_options = options.config.gpu_options();
-  std::vector<CudaGpuId> visible_gpu_order;
-  std::vector<CudaGpuId> valid_cuda_gpu_ids;
+  std::vector<PlatformGpuId> visible_gpu_order;
+  std::vector<PlatformGpuId> valid_platform_gpu_ids;
   // If we aren't going to use any GPUs, don't initialize them.
   // We don't want to call ParseVisibleDeviceList if num_gpus_to_use is 0,
   // because it treats an empty gpu_options.visible_device_list as 'all GPUs are
@@ -926,12 +941,12 @@
     TF_RETURN_IF_ERROR(ParseVisibleDeviceList(gpu_options.visible_device_list(),
                                               &visible_gpu_order));
     TF_RETURN_IF_ERROR(
-        GetValidDeviceIds(visible_gpu_order, &valid_cuda_gpu_ids));
+        GetValidDeviceIds(visible_gpu_order, &valid_platform_gpu_ids));
   }
-  if (num_gpus_to_use > valid_cuda_gpu_ids.size()) {
-    num_gpus_to_use = valid_cuda_gpu_ids.size();
+  if (num_gpus_to_use > valid_platform_gpu_ids.size()) {
+    num_gpus_to_use = valid_platform_gpu_ids.size();
   }
-  if (!valid_cuda_gpu_ids.empty()) {
+  if (!valid_platform_gpu_ids.empty()) {
     // Save the original device.
     int original_device = 0;
     cudaError_t err = cudaGetDevice(&original_device);
@@ -941,17 +956,18 @@
     }
     // Force to implicitly initialize CUDA runtime on each valid GPU before
     // CreateGPUDevice().
-    for (CudaGpuId cuda_gpu_id : valid_cuda_gpu_ids) {
-      err = cudaSetDevice(cuda_gpu_id.value());
+    for (PlatformGpuId platform_gpu_id : valid_platform_gpu_ids) {
+      err = cudaSetDevice(platform_gpu_id.value());
       if (err != cudaSuccess) {
-        return errors::Internal("cudaSetDevice() on GPU:", cuda_gpu_id.value(),
-                                " failed. Status: ", cudaGetErrorString(err));
+        return errors::Internal("cudaSetDevice() on GPU:",
+                                platform_gpu_id.value(), " failed. Status: ",
+                                cudaGetErrorString(err));
       }
       err = cudaFree(nullptr);
       if (err != cudaSuccess) {
-        return errors::Internal(
-            "CUDA runtime implicit initialization on GPU:", cuda_gpu_id.value(),
-            " failed. Status: ", cudaGetErrorString(err));
+        return errors::Internal("CUDA runtime implicit initialization on GPU:",
+                                platform_gpu_id.value(), " failed. Status: ",
+                                cudaGetErrorString(err));
       }
     }
     // Reset to the original device.
@@ -977,10 +993,10 @@
     LOG(INFO) << line_buf;
     for (int i = 0; i < visible_gpu_order.size(); ++i) {
       line_buf = strings::StrCat(visible_gpu_order[i].value(), ":   ");
-      CudaGpuId cuda_id_i = visible_gpu_order[i];
+      PlatformGpuId gpu_id_i = visible_gpu_order[i];
       for (int j = 0; j < visible_gpu_order.size(); ++j) {
-        CudaGpuId cuda_id_j = visible_gpu_order[j];
-        if (im.directed_links.find({cuda_id_i, cuda_id_j}) !=
+        PlatformGpuId gpu_id_j = visible_gpu_order[j];
+        if (im.directed_links.find({gpu_id_i, gpu_id_j}) !=
             im.directed_links.end()) {
           line_buf.append("Y ");
         } else {
@@ -993,22 +1009,23 @@
 
   const auto& virtual_devices = gpu_options.experimental().virtual_devices();
   if (!virtual_devices.empty()) {
-    TF_RETURN_IF_ERROR(VerifyVirtualDeviceSettings(
-        num_gpus_to_use, gpu_options, visible_gpu_order, valid_cuda_gpu_ids));
+    TF_RETURN_IF_ERROR(VerifyVirtualDeviceSettings(num_gpus_to_use, gpu_options,
+                                                   visible_gpu_order,
+                                                   valid_platform_gpu_ids));
     // We've verified that num_gpus_to_use >= virtual_devices.size().
     num_gpus_to_use = virtual_devices.size();
     CHECK(gpu_options.visible_device_list().empty() ||
-          valid_cuda_gpu_ids == visible_gpu_order);
+          valid_platform_gpu_ids == visible_gpu_order);
   }
   int next_tf_gpu_id = 0;
   std::vector<int64> memory_limit_bytes;
   for (int i = 0; i < num_gpus_to_use; ++i) {
-    const CudaGpuId cuda_gpu_id = valid_cuda_gpu_ids[i];
+    const PlatformGpuId platform_gpu_id = valid_platform_gpu_ids[i];
     if (virtual_devices.empty() ||
         virtual_devices.Get(i).memory_limit_mb_size() == 0) {
       int64 single_virtual_device_memory_limit = 0;
       TF_RETURN_IF_ERROR(SingleVirtualDeviceMemoryLimit(
-          gpu_options, cuda_gpu_id, &single_virtual_device_memory_limit));
+          gpu_options, platform_gpu_id, &single_virtual_device_memory_limit));
       memory_limit_bytes.push_back(single_virtual_device_memory_limit);
     } else {
       const auto& memory_limit_mb = virtual_devices.Get(i).memory_limit_mb();
@@ -1021,7 +1038,7 @@
       TfGpuId tf_gpu_id(next_tf_gpu_id);
       ++next_tf_gpu_id;
       TF_RETURN_IF_ERROR(
-          GpuIdManager::InsertTfCudaGpuIdPair(tf_gpu_id, cuda_gpu_id));
+          GpuIdManager::InsertTfPlatformGpuIdPair(tf_gpu_id, platform_gpu_id));
     }
   }
   const int num_tf_gpus = next_tf_gpu_id;
@@ -1046,7 +1063,7 @@
   return Status::OK();
 }
 
-static string GetShortDeviceDescription(CudaGpuId cuda_gpu_id,
+static string GetShortDeviceDescription(PlatformGpuId platform_gpu_id,
                                         const se::DeviceDescription& desc) {
   int cc_major;
   int cc_minor;
@@ -1055,9 +1072,8 @@
     cc_minor = 0;
   }
   // LINT.IfChange
-  return strings::StrCat("device: ", cuda_gpu_id.value(),
-                         ", name: ", desc.name(),
-                         ", pci bus id: ", desc.pci_bus_id(),
+  return strings::StrCat("device: ", platform_gpu_id.value(), ", name: ",
+                         desc.name(), ", pci bus id: ", desc.pci_bus_id(),
                          ", compute capability: ", cc_major, ".", cc_minor);
   // LINT.ThenChange(//tensorflow/python/platform/test.py)
 }
@@ -1072,12 +1088,13 @@
   const string device_name =
       strings::StrCat(name_prefix, "/device:GPU:", tf_gpu_id.value());
   GpuIdUtil::CheckValidTfGpuId(tf_gpu_id);
-  CudaGpuId cuda_gpu_id;
-  TF_RETURN_IF_ERROR(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id));
+  PlatformGpuId platform_gpu_id;
+  TF_RETURN_IF_ERROR(
+      GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id));
   int numa_node = dev_locality.numa_node();
 
   se::StreamExecutor* se =
-      GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+      GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
   const se::DeviceDescription& desc = se->GetDeviceDescription();
   GPUProcessState* process_state = GPUProcessState::singleton();
   Allocator* gpu_allocator = process_state->GetGPUAllocator(
@@ -1098,11 +1115,11 @@
   // TODO(laigd): report error if memory_limit doesn't match stats.bytes_limit.
   BaseGPUDevice* gpu_device = CreateGPUDevice(
       options, device_name, static_cast<Bytes>(stats.bytes_limit), dev_locality,
-      tf_gpu_id, GetShortDeviceDescription(cuda_gpu_id, desc), gpu_allocator,
-      ProcessState::singleton()->GetCPUAllocator(numa_node));
+      tf_gpu_id, GetShortDeviceDescription(platform_gpu_id, desc),
+      gpu_allocator, ProcessState::singleton()->GetCPUAllocator(numa_node));
   LOG(INFO) << "Created TensorFlow device (" << device_name << " with "
             << (stats.bytes_limit >> 20) << " MB memory) -> physical GPU ("
-            << GetShortDeviceDescription(cuda_gpu_id, desc) << ")";
+            << GetShortDeviceDescription(platform_gpu_id, desc) << ")";
   TF_RETURN_IF_ERROR(gpu_device->Init(options));
   devices->push_back(gpu_device);
 
@@ -1110,18 +1127,21 @@
 }
 
 namespace {
-std::unique_ptr<std::map<std::pair<CudaGpuId, CudaGpuId>, bool>>
+std::unique_ptr<std::map<std::pair<PlatformGpuId, PlatformGpuId>, bool>>
 GetPeerAccessMap(se::Platform* platform,
-                 const std::vector<CudaGpuId>& visible_gpu_order) {
-  std::unique_ptr<std::map<std::pair<CudaGpuId, CudaGpuId>, bool>> map(
-      new std::map<std::pair<CudaGpuId, CudaGpuId>, bool>);
-  for (CudaGpuId cuda_gpu_i : visible_gpu_order) {
-    for (CudaGpuId cuda_gpu_j : visible_gpu_order) {
+                 const std::vector<PlatformGpuId>& visible_gpu_order) {
+  std::unique_ptr<std::map<std::pair<PlatformGpuId, PlatformGpuId>, bool>> map(
+      new std::map<std::pair<PlatformGpuId, PlatformGpuId>, bool>);
+  for (PlatformGpuId platform_gpu_i : visible_gpu_order) {
+    for (PlatformGpuId platform_gpu_j : visible_gpu_order) {
       se::StreamExecutor* from =
-          GpuIdUtil::ExecutorForCudaGpuId(platform, cuda_gpu_i).ValueOrDie();
+          GpuIdUtil::ExecutorForPlatformGpuId(platform, platform_gpu_i)
+              .ValueOrDie();
       se::StreamExecutor* to =
-          GpuIdUtil::ExecutorForCudaGpuId(platform, cuda_gpu_j).ValueOrDie();
-      (*map)[{cuda_gpu_i, cuda_gpu_j}] = from->CanEnablePeerAccessTo(to);
+          GpuIdUtil::ExecutorForPlatformGpuId(platform, platform_gpu_j)
+              .ValueOrDie();
+      (*map)[{platform_gpu_i, platform_gpu_j}] =
+          from->CanEnablePeerAccessTo(to);
     }
   }
 
@@ -1131,19 +1151,19 @@
 }  // namespace
 
 Status BaseGPUDeviceFactory::GetInterconnectMaps(
-    const std::vector<CudaGpuId>& visible_gpu_order, se::Platform* gpu_manager,
-    std::vector<InterconnectMap>* maps) {
+    const std::vector<PlatformGpuId>& visible_gpu_order,
+    se::Platform* gpu_manager, std::vector<InterconnectMap>* maps) {
   // The default interconnect map is obtained from the StreamExecutor.
   auto access_map = GetPeerAccessMap(gpu_manager, visible_gpu_order);
   maps->resize(1);
   InterconnectMap& imap = maps->at(0);
   imap.name = "StreamExecutor";
   imap.strength = InterconnectMap::kStreamExecutorStrength;
-  for (CudaGpuId cuda_id_i : visible_gpu_order) {
-    for (CudaGpuId cuda_id_j : visible_gpu_order) {
-      if (cuda_id_i == cuda_id_j) continue;
-      if ((*access_map)[{cuda_id_i, cuda_id_j}]) {
-        imap.directed_links.insert({cuda_id_i, cuda_id_j});
+  for (PlatformGpuId gpu_id_i : visible_gpu_order) {
+    for (PlatformGpuId gpu_id_j : visible_gpu_order) {
+      if (gpu_id_i == gpu_id_j) continue;
+      if ((*access_map)[{gpu_id_i, gpu_id_j}]) {
+        imap.directed_links.insert({gpu_id_i, gpu_id_j});
       }
     }
   }
@@ -1158,13 +1178,14 @@
     all_tf_gpu_ids.push_back(TfGpuId(i));
   }
   for (TfGpuId tf_gpu_id : all_tf_gpu_ids) {
-    CudaGpuId cuda_gpu_id;
-    TF_RETURN_IF_ERROR(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id));
+    PlatformGpuId platform_gpu_id;
+    TF_RETURN_IF_ERROR(
+        GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id));
     // Get GPU bus_id from its reported NUMA affinity.  Because GPUs are
     // virtualized in some environments, we can't just use the GPU id.
     // NUMA locales are indexed from 0, buses are indexed from 1.
     se::StreamExecutor* se =
-        GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+        GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
     const se::DeviceDescription& desc = se->GetDeviceDescription();
     int numa_node = desc.numa_node();
     if (numa_node < 0) {
@@ -1174,7 +1195,8 @@
       // may run into trouble later with data transfer operations.  The
       // trouble may manifest as slower than expected performance, or
       // outright failures.
-      LOG(INFO) << "Could not identify NUMA node of CUDA gpu id " << cuda_gpu_id
+      LOG(INFO) << "Could not identify NUMA node of platform GPU id "
+                << platform_gpu_id
                 << ", defaulting to 0.  Your kernel may not have been built "
                 << "with NUMA support.";
       numa_node = 0;
@@ -1187,10 +1209,10 @@
     LocalLinks* links = dev_locality.mutable_links();
     for (const InterconnectMap& imap : interconnects) {
       for (TfGpuId tf_gpu_dst : all_tf_gpu_ids) {
-        CudaGpuId cuda_gpu_dst;
+        PlatformGpuId platform_gpu_dst;
         TF_RETURN_IF_ERROR(
-            GpuIdManager::TfToCudaGpuId(tf_gpu_dst, &cuda_gpu_dst));
-        if (imap.directed_links.find({cuda_gpu_id, cuda_gpu_dst}) !=
+            GpuIdManager::TfToPlatformGpuId(tf_gpu_dst, &platform_gpu_dst));
+        if (imap.directed_links.find({platform_gpu_id, platform_gpu_dst}) !=
             imap.directed_links.end()) {
           InterconnectLink* ilink = links->add_link();
           ilink->set_device_id(tf_gpu_dst.value());
@@ -1204,10 +1226,10 @@
     // add high strength links to the others.
     for (TfGpuId tf_gpu_dst : all_tf_gpu_ids) {
       if (tf_gpu_id == tf_gpu_dst) continue;
-      CudaGpuId cuda_gpu_dst;
+      PlatformGpuId platform_gpu_dst;
       TF_RETURN_IF_ERROR(
-          GpuIdManager::TfToCudaGpuId(tf_gpu_dst, &cuda_gpu_dst));
-      if (cuda_gpu_id == cuda_gpu_dst) {
+          GpuIdManager::TfToPlatformGpuId(tf_gpu_dst, &platform_gpu_dst));
+      if (platform_gpu_id == platform_gpu_dst) {
         InterconnectLink* ilink = links->add_link();
         ilink->set_device_id(tf_gpu_dst.value());
         ilink->set_type("SAME_DEVICE");
@@ -1216,9 +1238,9 @@
     }
 
     (*localities)[tf_gpu_id] = dev_locality;
-    VLOG(1) << "GPUDevice CudaGpuId " << cuda_gpu_id << " TfGpuId " << tf_gpu_id
-            << " on bus " << dev_locality.bus_id() << " numa: " << numa_node
-            << " pci: " << desc.pci_bus_id()
+    VLOG(1) << "GPUDevice PlatformGpuId " << platform_gpu_id << " TfGpuId "
+            << tf_gpu_id << " on bus " << dev_locality.bus_id()
+            << " numa: " << numa_node << " pci: " << desc.pci_bus_id()
             << " DeviceLocality: " << dev_locality.DebugString();
   }
   return Status::OK();
@@ -1226,14 +1248,14 @@
 
 static int GetDefaultMinGPUMultiprocessorCount(
     se::Platform* gpu_manager,
-    const std::vector<CudaGpuId>& visible_gpu_order) {
+    const std::vector<PlatformGpuId>& visible_gpu_order) {
   static const int kDefaultMinGPUMultiprocessorCount = 8;
 
   // Find the highest multi-processor count across all visible GPUs.
   int max_count = -1;
   for (int i = 0; i < visible_gpu_order.size(); ++i) {
     auto exec_status =
-        GpuIdUtil::ExecutorForCudaGpuId(gpu_manager, visible_gpu_order[i]);
+        GpuIdUtil::ExecutorForPlatformGpuId(gpu_manager, visible_gpu_order[i]);
     if (!exec_status.ok()) {
       continue;
     }
@@ -1252,7 +1274,7 @@
 
 static int GetMinGPUMultiprocessorCount(
     se::Platform* gpu_manager,
-    const std::vector<CudaGpuId>& visible_gpu_order) {
+    const std::vector<PlatformGpuId>& visible_gpu_order) {
   const char* tf_min_gpu_core_count = getenv("TF_MIN_GPU_MULTIPROCESSOR_COUNT");
 
   if (tf_min_gpu_core_count == nullptr ||
@@ -1330,18 +1352,20 @@
 }
 
 Status EnablePeerAccess(se::Platform* platform,
-                        const std::vector<CudaGpuId>& visible_gpu_order) {
+                        const std::vector<PlatformGpuId>& visible_gpu_order) {
   int possible_peer_count = 0;
   int enabled_peer_count = 0;
   for (int i = 0; i < visible_gpu_order.size(); ++i) {
-    const CudaGpuId cuda_gpu_i = visible_gpu_order[i];
+    const PlatformGpuId platform_gpu_i = visible_gpu_order[i];
     for (int j = 0; j < visible_gpu_order.size(); ++j) {
-      const CudaGpuId cuda_gpu_j = visible_gpu_order[j];
+      const PlatformGpuId platform_gpu_j = visible_gpu_order[j];
       // We have already validated that ExecutorForDevice() calls return OK.
       se::StreamExecutor* from =
-          GpuIdUtil::ExecutorForCudaGpuId(platform, cuda_gpu_i).ValueOrDie();
+          GpuIdUtil::ExecutorForPlatformGpuId(platform, platform_gpu_i)
+              .ValueOrDie();
       se::StreamExecutor* to =
-          GpuIdUtil::ExecutorForCudaGpuId(platform, cuda_gpu_j).ValueOrDie();
+          GpuIdUtil::ExecutorForPlatformGpuId(platform, platform_gpu_j)
+              .ValueOrDie();
 
       if (from->CanEnablePeerAccessTo(to)) {
         ++possible_peer_count;
@@ -1349,7 +1373,8 @@
         if (!status.ok()) {
           LOG(WARNING)
               << "Unable to enable peer access between device ordinals "
-              << cuda_gpu_i << " and " << cuda_gpu_j << ", status: " << status;
+              << platform_gpu_i << " and " << platform_gpu_j
+              << ", status: " << status;
         } else {
           ++enabled_peer_count;
         }
@@ -1372,22 +1397,23 @@
 }  // namespace
 
 Status BaseGPUDeviceFactory::GetValidDeviceIds(
-    const std::vector<CudaGpuId>& visible_gpu_order,
-    std::vector<CudaGpuId>* ids) {
+    const std::vector<PlatformGpuId>& visible_gpu_order,
+    std::vector<PlatformGpuId>* ids) {
   se::Platform* gpu_manager = GPUMachineManager();
   bool new_gpu_found = false;
   for (int i = 0; i < visible_gpu_order.size(); ++i) {
-    const CudaGpuId cuda_gpu_id = visible_gpu_order[i];
+    const PlatformGpuId visible_gpu_id = visible_gpu_order[i];
 
-    // Only perform this once per visible cuda gpu id.
-    if (visible_gpu_initialized_[cuda_gpu_id.value()]) {
+    // Only perform this once per visible platform gpu id.
+    if (visible_gpu_initialized_[visible_gpu_id.value()]) {
       continue;
     }
 
-    visible_gpu_initialized_[cuda_gpu_id.value()] = true;
+    visible_gpu_initialized_[visible_gpu_id.value()] = true;
     new_gpu_found = true;
 
-    auto executor = GpuIdUtil::ExecutorForCudaGpuId(gpu_manager, cuda_gpu_id);
+    auto executor =
+        GpuIdUtil::ExecutorForPlatformGpuId(gpu_manager, visible_gpu_id);
     if (!executor.ok()) {
       return executor.status();
     }
@@ -1435,9 +1461,9 @@
 
   // Filter out devices that don't have the right capability or power.
   for (int i = 0; i < visible_gpu_order.size(); ++i) {
-    const CudaGpuId visible_gpu_id = visible_gpu_order[i];
+    const PlatformGpuId visible_gpu_id = visible_gpu_order[i];
     auto exec_status =
-        GpuIdUtil::ExecutorForCudaGpuId(gpu_manager, visible_gpu_id);
+        GpuIdUtil::ExecutorForPlatformGpuId(gpu_manager, visible_gpu_id);
     if (!exec_status.ok()) {
       LOG(INFO) << "Ignoring visible gpu device " << visible_gpu_id
                 << " whose executor is in invalid state: "
@@ -1486,7 +1512,7 @@
   if (!ids->empty()) {
     std::vector<int> raw_ids(ids->size());
     std::transform(ids->begin(), ids->end(), raw_ids.begin(),
-                   [](CudaGpuId id) -> int { return id.value(); });
+                   [](PlatformGpuId id) -> int { return id.value(); });
     LOG(INFO) << "Adding visible gpu devices: "
               << str_util::Join(raw_ids, ", ");
   }
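The hunks above move scratch-buffer allocation out of BaseGPUDevice::Init() and into InitScratchBuffers(), which ReinitializeGpuDevice() calls lazily under scratch_init_mutex_ so a repeated call is a no-op. A minimal sketch of that lazy, idempotent initialization pattern, with a plain heap allocation standing in for the real AllocateRaw() + SynchronousMemZero() sequence (names here are illustrative only):

#include <mutex>
#include <vector>

class LazyScratch {
 public:
  explicit LazyScratch(int num_streams) : num_streams_(num_streams) {}
  ~LazyScratch() {
    for (char* p : scratch_) delete[] p;
  }

  // Safe to call repeatedly and from multiple threads; only the first call
  // allocates, later calls see the buffers already present and return.
  void EnsureInitialized() {
    std::lock_guard<std::mutex> lock(mu_);
    while (static_cast<int>(scratch_.size()) < num_streams_) {
      scratch_.push_back(new char[4096]());  // zero-initialized placeholder
    }
  }

 private:
  std::mutex mu_;
  int num_streams_;
  std::vector<char*> scratch_;
};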
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.h b/tensorflow/core/common_runtime/gpu/gpu_device.h
index 56d03d7..b25fe86 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.h
@@ -86,15 +86,16 @@
   // The caller owns the returned device.
   PerOpGpuDevice* MakeGpuDevice() override;
 
-  void ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device,
-                             DeviceContext* dc, Allocator* allocator) override;
+  Status ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device,
+                               DeviceContext* dc,
+                               Allocator* allocator) override;
 
-  // Returns the CUDA GPU id of this device within the native driver system;
+  // Returns the platform GPU id of this device within the native driver system;
   // e.g., for CUDA this is the ordinal of the GPU within the system.
   int gpu_id() const {
-    CudaGpuId cuda_gpu_id;
-    TF_CHECK_OK(GpuIdManager::TfToCudaGpuId(tf_gpu_id_, &cuda_gpu_id));
-    return cuda_gpu_id.value();
+    PlatformGpuId platform_gpu_id;
+    TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf_gpu_id_, &platform_gpu_id));
+    return platform_gpu_id.value();
   }
 
   // The executor that provides control for the device; e.g., for CUDA this
@@ -125,6 +126,7 @@
   class StreamGroupFactory;
 
   gtl::InlinedVector<StreamGroup*, 4> streams_;
+  mutex scratch_init_mutex_;
   gtl::InlinedVector<char*, 4> scratch_;
   std::vector<GPUDeviceContext*> device_contexts_;
   GpuDeviceInfo* gpu_device_info_ = nullptr;
@@ -135,6 +137,9 @@
   std::unique_ptr<EventMgr> em_;
   std::unique_ptr<thread::ThreadPool> thread_pool_;
 
+  // Initialize scratch buffers used by Eigen.
+  Status InitScratchBuffers();
+
   void ReinitializeDevice(OpKernelContext* context, PerOpGpuDevice* device,
                           int stream_id, Allocator* allocator);
 
@@ -168,14 +173,14 @@
     int32 strength;
     static const int kSameDeviceStrength;
     static const int kStreamExecutorStrength;
-    std::set<std::pair<CudaGpuId, CudaGpuId>> directed_links;
+    std::set<std::pair<PlatformGpuId, PlatformGpuId>> directed_links;
   };
 
  protected:
   // Populates *maps with interconnect maps for all local direct access
   // pathways between GPUs.
   virtual Status GetInterconnectMaps(
-      const std::vector<CudaGpuId>& visible_gpu_order,
+      const std::vector<PlatformGpuId>& visible_gpu_order,
       se::Platform* gpu_manager, std::vector<InterconnectMap>* maps);
 
   struct TfGpuIdHash {
@@ -207,16 +212,16 @@
                                          Allocator* gpu_allocator,
                                          Allocator* cpu_allocator) = 0;
 
-  // Returns into 'ids' the list of valid CUDA GPU ids, in the order that
+  // Returns into 'ids' the list of valid platform GPU ids, in the order that
   // they should map to TF GPU ids "/device:GPU:0", "/device:GPU:1", etc,
   // based upon 'visible_gpu_order' which was generated by parsing
   // GPUOptions::visible_device_list which is a comma-separated list of CUDA GPU
   // ids.
-  Status GetValidDeviceIds(const std::vector<CudaGpuId>& visible_gpu_order,
-                           std::vector<CudaGpuId>* ids);
+  Status GetValidDeviceIds(const std::vector<PlatformGpuId>& visible_gpu_order,
+                           std::vector<PlatformGpuId>* ids);
 
-  // visible_gpu_initialized_[cuda_gpu_id] is true if visible GPU cuda_gpu_id
-  // has been initialized by the process.
+  // visible_gpu_initialized_[platform_gpu_id] is true if visible GPU
+  // platform_gpu_id has been initialized by the process.
   std::unordered_map<int, bool> visible_gpu_initialized_;
 };
 
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
index daf59f0..3629409 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
@@ -30,18 +30,21 @@
 namespace {
 const char* kDeviceNamePrefix = "/job:localhost/replica:0/task:0";
 
-int64 GetTotalGPUMemory(CudaGpuId gpu_id) {
+int64 GetTotalGPUMemory(PlatformGpuId gpu_id) {
   se::StreamExecutor* se =
-      GpuIdUtil::ExecutorForCudaGpuId(GPUMachineManager(), gpu_id).ValueOrDie();
+      GpuIdUtil::ExecutorForPlatformGpuId(GPUMachineManager(), gpu_id)
+          .ValueOrDie();
 
   int64 total_memory, available_memory;
   CHECK(se->DeviceMemoryUsage(&available_memory, &total_memory));
   return total_memory;
 }
 
-Status GetComputeCapability(CudaGpuId gpu_id, int* cc_major, int* cc_minor) {
+Status GetComputeCapability(PlatformGpuId gpu_id, int* cc_major,
+                            int* cc_minor) {
   se::StreamExecutor* se =
-      GpuIdUtil::ExecutorForCudaGpuId(GPUMachineManager(), gpu_id).ValueOrDie();
+      GpuIdUtil::ExecutorForPlatformGpuId(GPUMachineManager(), gpu_id)
+          .ValueOrDie();
   if (!se->GetDeviceDescription().cuda_compute_capability(cc_major, cc_minor)) {
     *cc_major = 0;
     *cc_minor = 0;
@@ -223,7 +226,7 @@
 // error.
 TEST_F(GPUDeviceTest, UnifiedMemoryUnavailableOnPrePascalGpus) {
   int cc_major, cc_minor;
-  TF_ASSERT_OK(GetComputeCapability(CudaGpuId(0), &cc_major, &cc_minor));
+  TF_ASSERT_OK(GetComputeCapability(PlatformGpuId(0), &cc_major, &cc_minor));
   // Exit early while running on Pascal or later GPUs.
   if (cc_major >= 6) {
     return;
@@ -244,10 +247,10 @@
 // more memory than what is available on the device.
 TEST_F(GPUDeviceTest, UnifiedMemoryAllocation) {
   static constexpr double kGpuMemoryFraction = 1.2;
-  static constexpr CudaGpuId kCudaGpuId(0);
+  static constexpr PlatformGpuId kPlatformGpuId(0);
 
   int cc_major, cc_minor;
-  TF_ASSERT_OK(GetComputeCapability(kCudaGpuId, &cc_major, &cc_minor));
+  TF_ASSERT_OK(GetComputeCapability(kPlatformGpuId, &cc_major, &cc_minor));
   // Exit early if running on pre-Pascal GPUs.
   if (cc_major < 6) {
     LOG(INFO)
@@ -262,7 +265,7 @@
   ASSERT_EQ(1, devices.size());
 
   int64 memory_limit = devices[0]->attributes().memory_limit();
-  ASSERT_EQ(memory_limit, static_cast<int64>(GetTotalGPUMemory(kCudaGpuId) *
+  ASSERT_EQ(memory_limit, static_cast<int64>(GetTotalGPUMemory(kPlatformGpuId) *
                                              kGpuMemoryFraction));
 
   AllocatorAttributes allocator_attributes = AllocatorAttributes();
diff --git a/tensorflow/core/common_runtime/gpu/gpu_id.h b/tensorflow/core/common_runtime/gpu/gpu_id.h
index 2a6caea..f0d9022 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_id.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_id.h
@@ -25,10 +25,10 @@
 //   physical machine, it can be filtered by CUDA environment variable
 //   CUDA_VISIBLE_DEVICES. Note that this id is not visible to Tensorflow, but
 //   result after filtering by CUDA_VISIBLE_DEVICES is visible to TF and is
-//   called CUDA GPU id as below. See
+//   called platform GPU id as below. See
 //   http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars
 //   for more details.
-// - CUDA GPU id (also called *visible* GPU id in
+// - *platform* GPU id (also called *visible* GPU id in
 //   third_party/tensorflow/core/protobuf/config.proto): this is the id that is
 //   visible to Tensorflow after filtering by CUDA_VISIBLE_DEVICES, and is
 //   generated by the CUDA GPU driver. It starts from 0 and is used for CUDA API
@@ -39,14 +39,14 @@
 //   field of the device name "/device:GPU:<id>", and is also the identifier of
 //   a BaseGPUDevice. Note that the configuration allows us to create multiple
 //   BaseGPUDevice per GPU hardware in order to use multi CUDA streams on the
-//   hardware, so the mapping between TF GPU id and CUDA GPU id is not a 1:1
+//   hardware, so the mapping between TF GPU id and platform GPU id is not a 1:1
 //   mapping, see the example below.
 //
 // For example, assuming that in the machine we have GPU device with index 0, 1,
 // 2 and 3 (physical GPU id). Setting "CUDA_VISIBLE_DEVICES=1,2,3" will create
-// the following mapping between CUDA GPU id and physical GPU id:
+// the following mapping between platform GPU id and physical GPU id:
 //
-//        CUDA GPU id ->  physical GPU id
+//        platform GPU id ->  physical GPU id
 //                 0  ->  1
 //                 1  ->  2
 //                 2  ->  3
@@ -56,32 +56,32 @@
 //
 // Assuming we configure the Session to create one BaseGPUDevice per GPU
 // hardware, then setting GPUOptions::visible_device_list to "2,0" will create
-// the following mappting between TF GPU id and CUDA GPU id:
+// the following mapping between TF GPU id and platform GPU id:
 //
-//                  TF GPU id  ->  CUDA GPU ID
+//                  TF GPU id  ->  platform GPU ID
 //      0 (i.e. /device:GPU:0) ->  2
 //      1 (i.e. /device:GPU:1) ->  0
 //
-// Note that CUDA GPU id 1 is filtered out by GPUOptions::visible_device_list,
-// so it won't be used by the TF process.
+// Note that platform GPU id 1 is filtered out by
+// GPUOptions::visible_device_list, so it won't be used by the TF process.
 //
 // On the other hand, if we configure it to create 2 BaseGPUDevice per GPU
 // hardware, then setting GPUOptions::visible_device_list to "2,0" will create
-// the following mappting between TF GPU id and CUDA GPU id:
+// the following mapping between TF GPU id and platform GPU id:
 //
-//                  TF GPU id  ->  CUDA GPU ID
+//                  TF GPU id  ->  platform GPU ID
 //      0 (i.e. /device:GPU:0) ->  2
 //      1 (i.e. /device:GPU:1) ->  2
 //      2 (i.e. /device:GPU:2) ->  0
 //      3 (i.e. /device:GPU:3) ->  0
 //
-// We create strong-typed integer classes for both TF GPU id and CUDA GPU id to
-// minimize programming errors and improve code readability. Except for the
+// We create strongly typed integer classes for both TF GPU id and platform GPU
+// to minimize programming errors and improve code readability. Except for the
 // StreamExecutor interface (as we don't change its API), whenever we need a
-// TF GPU id (or CUDA GPU id) we should use TfGpuId (or CudaGpuId) instead of a
-// raw integer.
+// TF GPU id (or platform GPU id) we should use TfGpuId (or PlatformGpuId)
+// instead of a raw integer.
 TF_LIB_GTL_DEFINE_INT_TYPE(TfGpuId, int32);
-TF_LIB_GTL_DEFINE_INT_TYPE(CudaGpuId, int32);
+TF_LIB_GTL_DEFINE_INT_TYPE(PlatformGpuId, int32);
 
 }  // namespace tensorflow
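TF_LIB_GTL_DEFINE_INT_TYPE gives TfGpuId and PlatformGpuId distinct integer-like types, so passing one where the other is expected fails at compile time instead of silently misindexing a device. A stand-alone sketch of that idea (an illustrative stand-in, not the actual macro or the real class definitions):

#include <cstdint>

// Each tag type yields a distinct wrapper, so the two id spaces cannot be
// mixed without an explicit conversion through value().
template <typename Tag>
class StrongInt32 {
 public:
  StrongInt32() : value_(0) {}
  explicit StrongInt32(std::int32_t v) : value_(v) {}
  std::int32_t value() const { return value_; }
  bool operator==(const StrongInt32& other) const {
    return value_ == other.value_;
  }

 private:
  std::int32_t value_;
};

struct TfGpuIdTag {};
struct PlatformGpuIdTag {};
using IllustrativeTfGpuId = StrongInt32<TfGpuIdTag>;
using IllustrativePlatformGpuId = StrongInt32<PlatformGpuIdTag>;

// void Use(IllustrativePlatformGpuId id);
// Use(IllustrativeTfGpuId(0));        // would not compile: distinct types
// Use(IllustrativePlatformGpuId(0));  // OK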
 
diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc b/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc
index b5099dc..2b40730 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc
@@ -26,26 +26,27 @@
 
 namespace tensorflow {
 namespace {
-// Manages the map between TfGpuId and CUDA GPU id.
-class TfToCudaGpuIdMap {
+// Manages the map between TfGpuId and platform GPU id.
+class TfToPlatformGpuIdMap {
  public:
-  static TfToCudaGpuIdMap* singleton() {
-    static auto* id_map = new TfToCudaGpuIdMap;
+  static TfToPlatformGpuIdMap* singleton() {
+    static auto* id_map = new TfToPlatformGpuIdMap;
     return id_map;
   }
 
-  Status Insert(TfGpuId tf_gpu_id, CudaGpuId cuda_gpu_id) LOCKS_EXCLUDED(mu_) {
+  Status Insert(TfGpuId tf_gpu_id, PlatformGpuId platform_gpu_id)
+      LOCKS_EXCLUDED(mu_) {
     std::pair<IdMapType::iterator, bool> result;
     {
       mutex_lock lock(mu_);
-      result = id_map_.insert({tf_gpu_id.value(), cuda_gpu_id.value()});
+      result = id_map_.insert({tf_gpu_id.value(), platform_gpu_id.value()});
     }
-    if (!result.second && cuda_gpu_id.value() != result.first->second) {
+    if (!result.second && platform_gpu_id.value() != result.first->second) {
       return errors::AlreadyExists(
           "TensorFlow device (GPU:", tf_gpu_id.value(),
           ") is being mapped to "
           "multiple CUDA devices (",
-          cuda_gpu_id.value(), " now, and ", result.first->second,
+          platform_gpu_id.value(), " now, and ", result.first->second,
           " previously), which is not supported. "
           "This may be the result of providing different GPU configurations "
           "(ConfigProto.gpu_options, for example different visible_device_list)"
@@ -56,17 +57,17 @@
     return Status::OK();
   }
 
-  bool Find(TfGpuId tf_gpu_id, CudaGpuId* cuda_gpu_id) const
+  bool Find(TfGpuId tf_gpu_id, PlatformGpuId* platform_gpu_id) const
       LOCKS_EXCLUDED(mu_) {
     mutex_lock lock(mu_);
     auto result = id_map_.find(tf_gpu_id.value());
     if (result == id_map_.end()) return false;
-    *cuda_gpu_id = result->second;
+    *platform_gpu_id = result->second;
     return true;
   }
 
  private:
-  TfToCudaGpuIdMap() = default;
+  TfToPlatformGpuIdMap() = default;
 
   void TestOnlyReset() LOCKS_EXCLUDED(mu_) {
     mutex_lock lock(mu_);
@@ -78,17 +79,18 @@
   IdMapType id_map_ GUARDED_BY(mu_);
 
   friend class ::tensorflow::GpuIdManager;
-  TF_DISALLOW_COPY_AND_ASSIGN(TfToCudaGpuIdMap);
+  TF_DISALLOW_COPY_AND_ASSIGN(TfToPlatformGpuIdMap);
 };
 }  // namespace
 
-Status GpuIdManager::InsertTfCudaGpuIdPair(TfGpuId tf_gpu_id,
-                                           CudaGpuId cuda_gpu_id) {
-  return TfToCudaGpuIdMap::singleton()->Insert(tf_gpu_id, cuda_gpu_id);
+Status GpuIdManager::InsertTfPlatformGpuIdPair(TfGpuId tf_gpu_id,
+                                               PlatformGpuId platform_gpu_id) {
+  return TfToPlatformGpuIdMap::singleton()->Insert(tf_gpu_id, platform_gpu_id);
 }
 
-Status GpuIdManager::TfToCudaGpuId(TfGpuId tf_gpu_id, CudaGpuId* cuda_gpu_id) {
-  if (TfToCudaGpuIdMap::singleton()->Find(tf_gpu_id, cuda_gpu_id)) {
+Status GpuIdManager::TfToPlatformGpuId(TfGpuId tf_gpu_id,
+                                       PlatformGpuId* platform_gpu_id) {
+  if (TfToPlatformGpuIdMap::singleton()->Find(tf_gpu_id, platform_gpu_id)) {
     return Status::OK();
   }
   return errors::NotFound("TensorFlow device GPU:", tf_gpu_id.value(),
@@ -96,7 +98,7 @@
 }
 
 void GpuIdManager::TestOnlyReset() {
-  TfToCudaGpuIdMap::singleton()->TestOnlyReset();
+  TfToPlatformGpuIdMap::singleton()->TestOnlyReset();
 }
 
 }  // namespace tensorflow
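GpuIdManager above delegates to a mutex-protected singleton map whose Insert() accepts re-inserting the same (TF id, platform id) pair but rejects remapping an existing TF id to a different platform id. A small sketch of that insert-or-verify pattern with standard containers (names and the string-based error are illustrative, not the real Status API):

#include <mutex>
#include <string>
#include <unordered_map>

class IdMap {
 public:
  // Returns an empty string on success, an error message otherwise.
  std::string InsertOrVerify(int tf_id, int platform_id) {
    std::lock_guard<std::mutex> lock(mu_);
    auto result = map_.insert({tf_id, platform_id});
    if (!result.second && result.first->second != platform_id) {
      return "GPU:" + std::to_string(tf_id) + " is already mapped to " +
             std::to_string(result.first->second);
    }
    return "";
  }

  bool Find(int tf_id, int* platform_id) const {
    std::lock_guard<std::mutex> lock(mu_);
    auto it = map_.find(tf_id);
    if (it == map_.end()) return false;
    *platform_id = it->second;
    return true;
  }

 private:
  mutable std::mutex mu_;
  std::unordered_map<int, int> map_;
};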
diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_manager.h b/tensorflow/core/common_runtime/gpu/gpu_id_manager.h
index 491d92c..62df431 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_id_manager.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_id_manager.h
@@ -21,15 +21,17 @@
 
 namespace tensorflow {
 
-// Class that maintains a map from TfGpuId to CudaGpuId, and manages the
+// Class that maintains a map from TfGpuId to PlatformGpuId, and manages the
 // translation between them.
 class GpuIdManager {
  public:
-  // Adds a mapping from tf_gpu_id to cuda_gpu_id.
-  static Status InsertTfCudaGpuIdPair(TfGpuId tf_gpu_id, CudaGpuId cuda_gpu_id);
+  // Adds a mapping from tf_gpu_id to platform_gpu_id.
+  static Status InsertTfPlatformGpuIdPair(TfGpuId tf_gpu_id,
+                                          PlatformGpuId platform_gpu_id);
 
-  // Gets the cuda_gpu_id associated with tf_gpu_id. Returns OK if found.
-  static Status TfToCudaGpuId(TfGpuId tf_gpu_id, CudaGpuId* cuda_gpu_id);
+  // Gets the platform_gpu_id associated with tf_gpu_id. Returns OK if found.
+  static Status TfToPlatformGpuId(TfGpuId tf_gpu_id,
+                                  PlatformGpuId* platform_gpu_id);
 
   // Clears the map. Used in unit tests only.
   static void TestOnlyReset();
diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_manager_test.cc b/tensorflow/core/common_runtime/gpu/gpu_id_manager_test.cc
index a663ec7..8bf3c6a 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_id_manager_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_id_manager_test.cc
@@ -22,38 +22,38 @@
 namespace tensorflow {
 namespace {
 
-CudaGpuId TfToCudaGpuId(TfGpuId tf) {
-  CudaGpuId cuda;
-  TF_CHECK_OK(GpuIdManager::TfToCudaGpuId(tf, &cuda));
-  return cuda;
+PlatformGpuId TfToPlatformGpuId(TfGpuId tf) {
+  PlatformGpuId platform_gpu_id;
+  TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf, &platform_gpu_id));
+  return platform_gpu_id;
 }
 
 TEST(GpuIdManagerTest, Basics) {
   TfGpuId key_0(0);
-  CudaGpuId value_0(0);
-  TF_ASSERT_OK(GpuIdManager::InsertTfCudaGpuIdPair(key_0, value_0));
-  EXPECT_EQ(value_0, TfToCudaGpuId(key_0));
+  PlatformGpuId value_0(0);
+  TF_ASSERT_OK(GpuIdManager::InsertTfPlatformGpuIdPair(key_0, value_0));
+  EXPECT_EQ(value_0, TfToPlatformGpuId(key_0));
 
   // Multiple calls to map the same value is ok.
-  TF_ASSERT_OK(GpuIdManager::InsertTfCudaGpuIdPair(key_0, value_0));
-  EXPECT_EQ(value_0, TfToCudaGpuId(key_0));
+  TF_ASSERT_OK(GpuIdManager::InsertTfPlatformGpuIdPair(key_0, value_0));
+  EXPECT_EQ(value_0, TfToPlatformGpuId(key_0));
 
   // Map a different TfGpuId to a different value.
   TfGpuId key_1(3);
-  CudaGpuId value_1(2);
-  TF_ASSERT_OK(GpuIdManager::InsertTfCudaGpuIdPair(key_1, value_1));
-  EXPECT_EQ(value_1, TfToCudaGpuId(key_1));
+  PlatformGpuId value_1(2);
+  TF_ASSERT_OK(GpuIdManager::InsertTfPlatformGpuIdPair(key_1, value_1));
+  EXPECT_EQ(value_1, TfToPlatformGpuId(key_1));
 
   // Mapping a different TfGpuId to the same value is ok.
   TfGpuId key_2(10);
-  TF_ASSERT_OK(GpuIdManager::InsertTfCudaGpuIdPair(key_2, value_1));
-  EXPECT_EQ(value_1, TfToCudaGpuId(key_2));
+  TF_ASSERT_OK(GpuIdManager::InsertTfPlatformGpuIdPair(key_2, value_1));
+  EXPECT_EQ(value_1, TfToPlatformGpuId(key_2));
 
   // Mapping the same TfGpuId to a different value.
-  ASSERT_FALSE(GpuIdManager::InsertTfCudaGpuIdPair(key_2, value_0).ok());
+  ASSERT_FALSE(GpuIdManager::InsertTfPlatformGpuIdPair(key_2, value_0).ok());
 
   // Getting a nonexistent mapping.
-  ASSERT_FALSE(GpuIdManager::TfToCudaGpuId(TfGpuId(100), &value_0).ok());
+  ASSERT_FALSE(GpuIdManager::TfToPlatformGpuId(TfGpuId(100), &value_0).ok());
 }
 
 }  // namespace
diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_utils.h b/tensorflow/core/common_runtime/gpu/gpu_id_utils.h
index b9c66b3..b1f10fb 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_id_utils.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_id_utils.h
@@ -24,34 +24,37 @@
 
 namespace tensorflow {
 
-// Utility methods for translation between Tensorflow GPU ids and CUDA GPU ids.
+// Utility methods for translation between Tensorflow GPU ids and platform GPU
+// ids.
 class GpuIdUtil {
  public:
   // Convenient methods for getting the associated executor given a TfGpuId or
-  // CudaGpuId.
-  static se::port::StatusOr<se::StreamExecutor*> ExecutorForCudaGpuId(
-      se::Platform* gpu_manager, CudaGpuId cuda_gpu_id) {
-    return gpu_manager->ExecutorForDevice(cuda_gpu_id.value());
+  // PlatformGpuId.
+  static se::port::StatusOr<se::StreamExecutor*> ExecutorForPlatformGpuId(
+      se::Platform* gpu_manager, PlatformGpuId platform_gpu_id) {
+    return gpu_manager->ExecutorForDevice(platform_gpu_id.value());
   }
-  static se::port::StatusOr<se::StreamExecutor*> ExecutorForCudaGpuId(
-      CudaGpuId cuda_gpu_id) {
-    return ExecutorForCudaGpuId(GPUMachineManager(), cuda_gpu_id);
+  static se::port::StatusOr<se::StreamExecutor*> ExecutorForPlatformGpuId(
+      PlatformGpuId platform_gpu_id) {
+    return ExecutorForPlatformGpuId(GPUMachineManager(), platform_gpu_id);
   }
   static se::port::StatusOr<se::StreamExecutor*> ExecutorForTfGpuId(
       TfGpuId tf_gpu_id) {
-    CudaGpuId cuda_gpu_id;
-    TF_RETURN_IF_ERROR(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id));
-    return ExecutorForCudaGpuId(cuda_gpu_id);
+    PlatformGpuId platform_gpu_id;
+    TF_RETURN_IF_ERROR(
+        GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id));
+    return ExecutorForPlatformGpuId(platform_gpu_id);
   }
 
-  // Verify that the cuda_gpu_id associated with a TfGpuId is legitimate.
+  // Verify that the platform_gpu_id associated with a TfGpuId is legitimate.
   static void CheckValidTfGpuId(TfGpuId tf_gpu_id) {
-    CudaGpuId cuda_gpu_id;
-    TF_CHECK_OK(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id));
+    PlatformGpuId platform_gpu_id;
+    TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id));
     const int visible_device_count = GPUMachineManager()->VisibleDeviceCount();
-    CHECK_LT(cuda_gpu_id.value(), visible_device_count)
-        << "cuda_gpu_id is outside discovered device range."
-        << " TF GPU id: " << tf_gpu_id << " CUDA GPU id: " << cuda_gpu_id
+    CHECK_LT(platform_gpu_id.value(), visible_device_count)
+        << "platform_gpu_id is outside discovered device range."
+        << " TF GPU id: " << tf_gpu_id
+        << " platform GPU id: " << platform_gpu_id
         << " visible device count: " << visible_device_count;
   }
 };
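GpuIdUtil::ExecutorForTfGpuId above chains two lookups: TF GPU id to platform GPU id through GpuIdManager, then platform GPU id to a StreamExecutor. A minimal sketch of that two-step translation, using std::optional in place of StatusOr and illustrative callbacks for the two lookups:

#include <functional>
#include <optional>

struct Executor {};  // stand-in for se::StreamExecutor

// Fails (returns nullopt) if the TF id has no platform mapping; otherwise
// forwards the platform id to the executor factory.
std::optional<Executor*> ExecutorForTfId(
    int tf_gpu_id,
    const std::function<std::optional<int>(int)>& tf_to_platform,
    const std::function<Executor*(int)>& executor_for_platform) {
  std::optional<int> platform_id = tf_to_platform(tf_gpu_id);
  if (!platform_id.has_value()) return std::nullopt;
  return executor_for_platform(*platform_id);
}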
diff --git a/tensorflow/core/common_runtime/gpu/gpu_process_state.cc b/tensorflow/core/common_runtime/gpu/gpu_process_state.cc
index b186881..3e95374 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_process_state.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_process_state.cc
@@ -76,12 +76,16 @@
 // This function is defined for debugging problems with the allocators.
 GPUProcessState::~GPUProcessState() {
   CHECK_EQ(this, instance_);
-  for (auto p : gpu_allocators_) {
-    delete p;
-  }
   instance_ = nullptr;
 }
 
+int GPUProcessState::BusIdForGPU(TfGpuId tf_gpu_id) {
+  // Return the NUMA node associated with the GPU's StreamExecutor.
+  se::StreamExecutor* se =
+      GpuIdUtil::ExecutorForTfGpuId(tf_gpu_id).ValueOrDie();
+  return se->GetDeviceDescription().numa_node();
+}
+
 Allocator* GPUProcessState::GetGPUAllocator(const GPUOptions& options,
                                             TfGpuId tf_gpu_id,
                                             size_t total_bytes) {
@@ -93,64 +97,63 @@
 
   if (tf_gpu_id.value() >= static_cast<int64>(gpu_allocators_.size())) {
     gpu_allocators_.resize(tf_gpu_id.value() + 1);
-    if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types)
-      gpu_al_.resize(tf_gpu_id.value() + 1);
   }
 
-  if (gpu_allocators_[tf_gpu_id.value()] == nullptr) {
-    VisitableAllocator* gpu_allocator;
-
+  AllocatorParts& allocator_parts = gpu_allocators_[tf_gpu_id.value()];
+  if (allocator_parts.allocator.get() == nullptr) {
     // Validate allocator types.
     if (!allocator_type.empty() && allocator_type != "BFC") {
       LOG(ERROR) << "Invalid allocator type: " << allocator_type;
       return nullptr;
     }
 
-    CudaGpuId cuda_gpu_id;
-    TF_CHECK_OK(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id));
-    gpu_allocator =
-        new GPUBFCAllocator(cuda_gpu_id, total_bytes, options,
+    PlatformGpuId platform_gpu_id;
+    TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id));
+    int bus_id = BusIdForGPU(tf_gpu_id);
+    while (bus_id >= gpu_visitors_.size()) {
+      gpu_visitors_.push_back({});
+    }
+    GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+        GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+        platform_gpu_id,
+        (options.per_process_gpu_memory_fraction() > 1.0 ||
+         options.experimental().use_unified_memory()),
+        gpu_visitors_[bus_id], {});
+    Allocator* gpu_allocator =
+        new GPUBFCAllocator(sub_allocator, total_bytes, options,
                             strings::StrCat("GPU_", tf_gpu_id.value(), "_bfc"));
 
     // If true, checks for memory overwrites by writing
     // distinctive patterns on both ends of allocated memory.
     if (useCudaMemoryGuardAllocator()) {
-      gpu_allocator = new GPUDebugAllocator(gpu_allocator, cuda_gpu_id);
-      gpu_allocator = new GPUNanResetAllocator(gpu_allocator, cuda_gpu_id);
+      gpu_allocator = new GPUDebugAllocator(gpu_allocator, platform_gpu_id);
+      gpu_allocator = new GPUNanResetAllocator(gpu_allocator, platform_gpu_id);
     } else if (useCudaMallocAllocator()) {
       // If true, passes all allocation requests through to cudaMalloc
       // useful for doing memory debugging with tools like cuda-memcheck
       // **WARNING** probably will not work in a multi-gpu scenario
-      gpu_allocator = new GPUcudaMallocAllocator(gpu_allocator, cuda_gpu_id);
+      gpu_allocator =
+          new GPUcudaMallocAllocator(gpu_allocator, platform_gpu_id);
     }
-    gpu_allocators_[tf_gpu_id.value()] = gpu_allocator;
 
-    // If there are any pending AllocVisitors for this bus, add
-    // them now.
-    se::StreamExecutor* se =
-        GpuIdUtil::ExecutorForTfGpuId(tf_gpu_id).ValueOrDie();
-    int bus_id = se->GetDeviceDescription().numa_node();
-    if (bus_id >= 0 && bus_id < static_cast<int64>(gpu_visitors_.size())) {
-      for (const auto& v : gpu_visitors_[bus_id]) {
-        gpu_allocator->AddAllocVisitor(v);
-      }
-    }
+    Allocator* recording_allocator = nullptr;
     if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
       ProcessState::MemDesc md;
       md.loc = ProcessState::MemDesc::GPU;
-      md.dev_index = cuda_gpu_id.value();
+      md.dev_index = platform_gpu_id.value();
       md.gpu_registered = false;
       md.nic_registered = true;
-      if (static_cast<int64>(gpu_al_.size()) <= tf_gpu_id.value()) {
-        gpu_al_.resize(tf_gpu_id.value() + 1);
-      }
-      gpu_al_[tf_gpu_id.value()] = new internal::RecordingAllocator(
+      recording_allocator = new internal::RecordingAllocator(
           &process_state_->mem_desc_map_, gpu_allocator, md, &mu_);
     }
+    allocator_parts = {std::unique_ptr<Allocator>(gpu_allocator), sub_allocator,
+                       std::unique_ptr<Allocator>(recording_allocator)};
   }
-  if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types)
-    return gpu_al_[tf_gpu_id.value()];
-  return gpu_allocators_[tf_gpu_id.value()];
+  if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
+    return allocator_parts.recording_allocator.get();
+  } else {
+    return allocator_parts.allocator.get();
+  }
 #else
   LOG(FATAL) << "GPUAllocator unavailable. Not compiled with --config=cuda.";
   return nullptr;
@@ -172,11 +175,12 @@
     tf_shared_lock lock(mu_);
 
     if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types &&
-        static_cast<int>(cuda_al_.size()) > 0) {
-      return cuda_al_[0];
+        !cuda_host_allocators_.empty() &&
+        cuda_host_allocators_[0].recording_allocator != nullptr) {
+      return cuda_host_allocators_[0].recording_allocator.get();
     }
     if (static_cast<int>(cuda_host_allocators_.size()) > numa_node) {
-      return cuda_host_allocators_[0];
+      return cuda_host_allocators_[0].allocator.get();
     }
   }
 
@@ -190,7 +194,7 @@
   // it knows is valid.
   se::StreamExecutor* se = nullptr;
   for (int i = 0; i < static_cast<int>(gpu_allocators_.size()); ++i) {
-    if (gpu_allocators_[i] != nullptr) {
+    if (gpu_allocators_[i].allocator != nullptr) {
       se = GpuIdUtil::ExecutorForTfGpuId(TfGpuId(i)).ValueOrDie();
       break;
     }
@@ -199,6 +203,15 @@
   CHECK_NE(nullptr, se);
 
   while (static_cast<int>(cuda_host_allocators_.size()) <= numa_node) {
+    while (cuda_host_alloc_visitors_.size() <= numa_node) {
+      cuda_host_alloc_visitors_.push_back({});
+    }
+    while (cuda_host_free_visitors_.size() <= numa_node) {
+      cuda_host_free_visitors_.push_back({});
+    }
+    SubAllocator* sub_allocator = new CUDAHostAllocator(
+        se, numa_node, cuda_host_alloc_visitors_[numa_node],
+        cuda_host_free_visitors_[numa_node]);
     // TODO(zheng-xq): evaluate whether 64GB by default is the best choice.
     int64 cuda_host_mem_limit_in_mb = -1;
     Status status = ReadInt64FromEnvVar("TF_CUDA_HOST_MEM_LIMIT_IN_MB",
@@ -208,62 +221,92 @@
       LOG(ERROR) << "GetCUDAHostAllocator: " << status.error_message();
     }
     int64 cuda_host_mem_limit = cuda_host_mem_limit_in_mb * (1LL << 20);
-    VisitableAllocator* allocator =
-        new BFCAllocator(new CUDAHostAllocator(se), cuda_host_mem_limit,
+    Allocator* allocator =
+        new BFCAllocator(sub_allocator, cuda_host_mem_limit,
                          true /*allow_growth*/, "cuda_host_bfc" /*name*/);
 
-    if (LogMemory::IsEnabled()) {
+    if (LogMemory::IsEnabled() && !allocator->TracksAllocationSizes()) {
       // Wrap the allocator to track allocation ids for better logging
       // at the cost of performance.
-      allocator = new TrackingVisitableAllocator(allocator, true);
+      allocator = new TrackingAllocator(allocator, true);
     }
-    cuda_host_allocators_.push_back(allocator);
+    cuda_host_allocators_.push_back({std::unique_ptr<Allocator>(allocator),
+                                     sub_allocator,
+                                     std::unique_ptr<Allocator>(nullptr)});
+    AllocatorParts& allocator_parts = cuda_host_allocators_.back();
     if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
       ProcessState::MemDesc md;
       md.loc = ProcessState::MemDesc::CPU;
       md.dev_index = 0;
       md.gpu_registered = true;
       md.nic_registered = false;
-      cuda_al_.push_back(new internal::RecordingAllocator(
-          &process_state_->mem_desc_map_, cuda_host_allocators_.back(), md,
-          &mu_));
+      allocator_parts.recording_allocator.reset(
+          new internal::RecordingAllocator(&process_state_->mem_desc_map_,
+                                           allocator_parts.allocator.get(), md,
+                                           &mu_));
     }
   }
-  if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types)
-    return cuda_al_[0];
-  return cuda_host_allocators_[0];
+  if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
+    return cuda_host_allocators_[0].recording_allocator.get();
+  } else {
+    return cuda_host_allocators_[0].allocator.get();
+  }
 }
 
 void GPUProcessState::AddGPUAllocVisitor(int bus_id,
-                                         const AllocVisitor& visitor) {
-  CHECK(process_state_);
+                                         const SubAllocator::Visitor& visitor) {
 #if GOOGLE_CUDA
   mutex_lock lock(mu_);
-  for (int i = 0; i < static_cast<int64>(gpu_allocators_.size()); ++i) {
-    se::StreamExecutor* se =
-        GpuIdUtil::ExecutorForTfGpuId(TfGpuId(i)).ValueOrDie();
-    if (gpu_allocators_[i] &&
-        (se->GetDeviceDescription().numa_node() + 1) == bus_id) {
-      gpu_allocators_[i]->AddAllocVisitor(visitor);
-    }
-  }
+  CHECK(gpu_allocators_.empty())  // Crash OK
+      << "AddGPUAllocVisitor must be called before "
+         "first call to GetGPUAllocator.";
   while (bus_id >= static_cast<int64>(gpu_visitors_.size())) {
-    gpu_visitors_.push_back(std::vector<AllocVisitor>());
+    gpu_visitors_.push_back(std::vector<SubAllocator::Visitor>());
   }
   gpu_visitors_[bus_id].push_back(visitor);
 #endif  // GOOGLE_CUDA
 }
 
+void GPUProcessState::AddCUDAHostAllocVisitor(
+    int numa_node, const SubAllocator::Visitor& visitor) {
+#if GOOGLE_CUDA
+  mutex_lock lock(mu_);
+  CHECK(cuda_host_allocators_.empty())  // Crash OK
+      << "AddCUDAHostAllocVisitor must be called before "
+         "first call to GetCUDAHostAllocator.";
+  while (numa_node >= static_cast<int64>(cuda_host_alloc_visitors_.size())) {
+    cuda_host_alloc_visitors_.push_back(std::vector<SubAllocator::Visitor>());
+  }
+  cuda_host_alloc_visitors_[numa_node].push_back(visitor);
+#endif  // GOOGLE_CUDA
+}
+
+void GPUProcessState::AddCUDAHostFreeVisitor(
+    int numa_node, const SubAllocator::Visitor& visitor) {
+#if GOOGLE_CUDA
+  mutex_lock lock(mu_);
+  CHECK(cuda_host_allocators_.empty())  // Crash OK
+      << "AddCUDAHostFreeVisitor must be called before "
+         "first call to GetCUDAHostAllocator.";
+  while (numa_node >= static_cast<int64>(cuda_host_free_visitors_.size())) {
+    cuda_host_free_visitors_.push_back(std::vector<SubAllocator::Visitor>());
+  }
+  cuda_host_free_visitors_[numa_node].push_back(visitor);
+#endif  // GOOGLE_CUDA
+}
+
 void GPUProcessState::TestOnlyReset() {
-  process_state_->ProcessState::TestOnlyReset();
+  if (process_state_) {
+    process_state_->ProcessState::TestOnlyReset();
+  }
   {
     mutex_lock lock(mu_);
     gpu_device_enabled_ = false;
+    gpu_allocators_.clear();
     gpu_visitors_.clear();
-    gtl::STLDeleteElements(&gpu_allocators_);
-    gtl::STLDeleteElements(&cuda_host_allocators_);
-    gtl::STLDeleteElements(&gpu_al_);
-    gtl::STLDeleteElements(&cuda_al_);
+    cuda_host_allocators_.clear();
+    cuda_host_alloc_visitors_.clear();
+    cuda_host_free_visitors_.clear();
   }
 }
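
A consequence of the change above is that allocation visitors can no longer be attached to an already-built GPU allocator; they must be registered on GPUProcessState before the first GetGPUAllocator call so they can be passed into the GPUMemAllocator. A hedged sketch, assuming the usual GPUProcessState::singleton() accessor; the bus id and the lambda body are illustrative only:

```c++
// Illustrative only: register a visitor for bus 0 before any GPU allocator
// has been created.
GPUProcessState* ps = GPUProcessState::singleton();
ps->AddGPUAllocVisitor(
    /*bus_id=*/0,
    [](void* ptr, int index, size_t num_bytes) {
      // `index` is the device index supplied by the sub-allocator.
      VLOG(1) << "GPU chunk " << ptr << " of " << num_bytes << " bytes";
    });
// Later GetGPUAllocator(...) calls bake the visitor into the GPUMemAllocator;
// calling AddGPUAllocVisitor afterwards now CHECK-fails.
```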
 
diff --git a/tensorflow/core/common_runtime/gpu/gpu_process_state.h b/tensorflow/core/common_runtime/gpu/gpu_process_state.h
index cb41c3c..43e9a31 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_process_state.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_process_state.h
@@ -32,7 +32,6 @@
 namespace tensorflow {
 
 class Allocator;
-class VisitableAllocator;
 class PoolAllocator;
 
 // Singleton that manages per-process state when GPUs are present.
@@ -72,18 +71,30 @@
 
   virtual Allocator* GetCUDAHostAllocator(int numa_node);
 
-  // Registers a function to be called once on every new Region
-  // allocated by every GPURegionAllocator proximate to the specified
-  // bus.  The AllocVisitor is provided with a memory pointer and the
-  // size of the area it identifies.  The pointer is not guaranteed to
-  // be valid after the call terminates.  The intention is for this
-  // interface to be used for network device memory registration.
-  // "bus_id" is platform-specific.  On many platforms it
-  // should be 0.  On machines with multiple PCIe buses, it should be
-  // the index of one of the PCIe buses.  If the bus_id is invalid,
-  // results are undefined.
-  typedef std::function<void(void*, size_t)> AllocVisitor;
-  virtual void AddGPUAllocVisitor(int bus_id, const AllocVisitor& visitor);
+  // Registers a Visitor to be invoked on new chunks of memory allocated by the
+  // SubAllocator of every GPU proximate to the specified bus.  The Visitor is
+  // provided with a memory pointer, a GPU id, and the size of the area it
+  // identifies.  The pointer is not guaranteed to be valid after the call
+  // terminates.  The intention is for this interface to be used for network
+  // device memory registration.  "bus_id" is platform-specific.  On many
+  // platforms it should be 0.  On machines with multiple PCIe buses, it should
+  // be the index of one of the PCIe buses (which may be the NUMA node at which
+  // the PCIe bus is rooted).  If the bus_id is invalid, results are undefined.
+  virtual void AddGPUAllocVisitor(int bus_id,
+                                  const SubAllocator::Visitor& visitor);
+
+  // Registers a Visitor to be invoked on new chunks of memory allocated by
+  // the SubAllocator of the CUDAHostAllocator for the given numa_node.
+  virtual void AddCUDAHostAllocVisitor(int numa_node,
+                                       const SubAllocator::Visitor& visitor);
+
+  // Registers a Visitor to be invoked on each chunk handed back for freeing to
+  // the SubAllocator of the CUDAHostAllocator for the given numa_node.
+  virtual void AddCUDAHostFreeVisitor(int numa_node,
+                                      const SubAllocator::Visitor& visitor);
+
+  // Returns the bus_id (NUMA node) for the given TF GPU id.
+  virtual int BusIdForGPU(TfGpuId tf_gpu_id);
 
  protected:
   GPUProcessState();
@@ -103,17 +114,22 @@
 
   mutex mu_;
 
-  std::vector<VisitableAllocator*> gpu_allocators_ GUARDED_BY(mu_);
-  std::vector<std::vector<AllocVisitor>> gpu_visitors_ GUARDED_BY(mu_);
-  std::vector<Allocator*> cuda_host_allocators_ GUARDED_BY(mu_);
+  struct AllocatorParts {
+    std::unique_ptr<Allocator> allocator;
+    SubAllocator* sub_allocator;  // owned by allocator
+    std::unique_ptr<Allocator> recording_allocator;
+  };
+  std::vector<AllocatorParts> gpu_allocators_ GUARDED_BY(mu_);
+  std::vector<std::vector<SubAllocator::Visitor>> gpu_visitors_ GUARDED_BY(mu_);
+
+  std::vector<AllocatorParts> cuda_host_allocators_ GUARDED_BY(mu_);
+  std::vector<std::vector<SubAllocator::Visitor>> cuda_host_alloc_visitors_
+      GUARDED_BY(mu_);
+  std::vector<std::vector<SubAllocator::Visitor>> cuda_host_free_visitors_
+      GUARDED_BY(mu_);
 
   virtual ~GPUProcessState();
 
-  // Optional RecordingAllocators that wrap the corresponding
-  // Allocators for runtime attribute use analysis.
-  std::vector<Allocator*> gpu_al_ GUARDED_BY(mu_);
-  std::vector<Allocator*> cuda_al_ GUARDED_BY(mu_);
-
   friend class GPUDeviceTest;
 };
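
The two new CUDA-host hooks follow the same per-NUMA-node pattern, and the intended use is registering pinned host memory with a network stack. A sketch under the same register-before-first-use constraint; the RDMA-style helper functions are hypothetical and not part of this change:

```c++
// Hypothetical callbacks; only the registration pattern comes from the header.
GPUProcessState* ps = GPUProcessState::singleton();
ps->AddCUDAHostAllocVisitor(
    /*numa_node=*/0,
    [](void* ptr, int numa_node, size_t num_bytes) {
      RegisterHostMemoryWithNic(ptr, num_bytes);    // hypothetical helper
    });
ps->AddCUDAHostFreeVisitor(
    /*numa_node=*/0,
    [](void* ptr, int numa_node, size_t num_bytes) {
      UnregisterHostMemoryWithNic(ptr, num_bytes);  // hypothetical helper
    });
```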
 
diff --git a/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc b/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc
index 583bff2..6b2f654 100644
--- a/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc
+++ b/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc
@@ -31,7 +31,8 @@
       2 /*pool_size_limit*/, false /*auto_resize*/,
       new CUDAHostAllocator(
           platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
-              .ValueOrDie()),
+              .ValueOrDie(),
+          0 /*numa_node*/, {}, {}),
       new NoopRounder, "pool");
 
   EXPECT_EQ(nullptr, pool.AllocateRaw(4 /*alignment*/, 0 /*num_bytes*/));
@@ -49,7 +50,8 @@
       0 /*pool_size_limit*/, false /*auto_resize*/,
       new CUDAHostAllocator(
           platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
-              .ValueOrDie()),
+              .ValueOrDie(),
+          0 /*numa_node*/, {}, {}),
       new NoopRounder, "pool");
 
   EXPECT_EQ(0, pool.get_from_pool_count());
@@ -82,7 +84,8 @@
       0 /*pool_size_limit*/, false /*auto_resize*/,
       new CUDAHostAllocator(
           platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
-              .ValueOrDie()),
+              .ValueOrDie(),
+          0 /*numa_node*/, {}, {}),
       new NoopRounder, "pool");
   for (int i = 0; i < 16; ++i) {
     size_t alignment = 1 << i;
@@ -97,8 +100,8 @@
 
 TEST(PoolAllocatorTest, AutoResize) {
   PoolAllocator pool(2 /*pool_size_limit*/, true /*auto_resize*/,
-                     new BasicCPUAllocator(0 /*numa_node*/), new NoopRounder,
-                     "pool");
+                     new BasicCPUAllocator(0 /*numa_node*/, {}, {}),
+                     new NoopRounder, "pool");
 
   // Alloc/dealloc 10 sizes just a few times, confirming pool size
   // stays at 2.
@@ -123,14 +126,32 @@
 }
 
 TEST(PoolAllocatorTest, CudaHostAllocator) {
+  int alloc_count = 0;
+  int64 alloc_size = 0;
+  SubAllocator::Visitor alloc_visitor =
+      [&alloc_count, &alloc_size](void* ptr, int numa_node, int64 size) {
+        ++alloc_count;
+        alloc_size += size;
+      };
+  int free_count = 0;
+  int64 free_size = 0;
+  SubAllocator::Visitor free_visitor =
+      [&free_count, &free_size](void* ptr, int numa_node, int64 size) {
+        ++free_count;
+        free_size += size;
+      };
   se::Platform* platform =
       se::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
-  PoolAllocator pool(
-      2 /*pool_size_limit*/, false /*auto_resize*/,
-      new CUDAHostAllocator(
-          platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
-              .ValueOrDie()),
-      new NoopRounder, "pool");
+  CUDAHostAllocator* sub_allocator = new CUDAHostAllocator(
+      platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
+          .ValueOrDie(),
+      0 /*numa_node*/, {alloc_visitor}, {free_visitor});
+  PoolAllocator pool(2 /*pool_size_limit*/, false /*auto_resize*/,
+                     sub_allocator, new NoopRounder, "pool");
+  EXPECT_EQ(0, alloc_count);
+  EXPECT_EQ(0, alloc_size);
+  EXPECT_EQ(0, free_count);
+  EXPECT_EQ(0, free_size);
 
   // Repeatedly Get a 16-byte value, confirming that there's only
   // one real allocation.
@@ -138,6 +159,10 @@
   EXPECT_EQ(0, pool.get_from_pool_count());
   EXPECT_EQ(1, pool.allocated_count());
   EXPECT_NE(nullptr, p1_16);
+  EXPECT_EQ(1, alloc_count);  // Underlying suballoc of 16 bytes
+  // Each suballocation includes a 16B ChunkPrefix.
+  static const int kChunkPrefixSize = 16;
+  EXPECT_EQ(16 + (alloc_count * kChunkPrefixSize), alloc_size);
   pool.DeallocateRaw(p1_16);
   // Pool contents {16}
   EXPECT_EQ(1, pool.put_count());
@@ -148,6 +173,9 @@
   pool.DeallocateRaw(p2_16);  // Put it back.
   // Pool contents {16}
   EXPECT_EQ(2, pool.put_count());
+  EXPECT_EQ(1, alloc_count);  // Underlying suballoc of 16 bytes
+  EXPECT_EQ(16 + (alloc_count * kChunkPrefixSize), alloc_size);
+  EXPECT_EQ(0, free_count);
 
   // Get two more values of different sizes.
   void* p3_4 = pool.AllocateRaw(4, 4);
@@ -160,6 +188,9 @@
   void* p4_2 = pool.AllocateRaw(4, 2);  // Get a third size buffer.
   EXPECT_NE(nullptr, p4_2);
   EXPECT_EQ(0, pool.evicted_count());
+  EXPECT_EQ(3, alloc_count);
+  EXPECT_EQ(16 + 4 + 2 + (alloc_count * kChunkPrefixSize), alloc_size);
+  EXPECT_EQ(0, free_count);
 
   // The pool is full: when we put back p4_2, the 16-byte buffer
   // should be evicted since it was least recently inserted.
@@ -167,6 +198,10 @@
   // Pool contents {2, 4}
   EXPECT_EQ(4, pool.put_count());
   EXPECT_EQ(1, pool.evicted_count());
+  EXPECT_EQ(3, alloc_count);
+  EXPECT_EQ(16 + 4 + 2 + (alloc_count * kChunkPrefixSize), alloc_size);
+  EXPECT_EQ(1, free_count);
+  EXPECT_EQ(16 + (free_count * kChunkPrefixSize), free_size);
 
   // Re-getting and putting size 2 or 4 should not alter pool size or
   // num-evicted.
@@ -180,12 +215,20 @@
   EXPECT_EQ(6, pool.put_count());
   EXPECT_EQ(3, pool.allocated_count());
   EXPECT_EQ(1, pool.evicted_count());
+  EXPECT_EQ(3, alloc_count);
+  EXPECT_EQ(16 + 4 + 2 + (alloc_count * kChunkPrefixSize), alloc_size);
+  EXPECT_EQ(1, free_count);
+  EXPECT_EQ(16 + (free_count * kChunkPrefixSize), free_size);
 
   pool.Clear();
   EXPECT_EQ(0, pool.get_from_pool_count());
   EXPECT_EQ(0, pool.put_count());
   EXPECT_EQ(0, pool.allocated_count());
   EXPECT_EQ(0, pool.evicted_count());
+  EXPECT_EQ(3, alloc_count);
+  EXPECT_EQ(16 + 4 + 2 + (alloc_count * kChunkPrefixSize), alloc_size);
+  EXPECT_EQ(3, free_count);
+  EXPECT_EQ(16 + 4 + 2 + (free_count * kChunkPrefixSize), free_size);
 }
 
 TEST(PoolAllocatorTest, Pow2Rounder) {
@@ -206,7 +249,8 @@
       2 /*pool_size_limit*/, false /*auto_resize*/,
       new CUDAHostAllocator(
           platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
-              .ValueOrDie()),
+              .ValueOrDie(),
+          0 /*numa_node*/, {}, {}),
       new NoopRounder, "pool");
   EXPECT_EQ("pool", pool.Name());
 }
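
For reference, the visitor expectations added to the CudaHostAllocator test above hinge on one detail: every suballocation PoolAllocator requests from its SubAllocator is padded with a 16-byte ChunkPrefix, and the visitors see the padded size. After the first 16-byte Get, the alloc visitor has therefore observed one call totalling 16 + 16 = 32 bytes; after the later 4-byte and 2-byte gets it has observed three calls totalling 16 + 4 + 2 + 3 * 16 = 70 bytes, which is exactly what the EXPECT_EQs encode.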
diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc
index 7f260b3..4475fa9 100644
--- a/tensorflow/core/common_runtime/graph_execution_state.cc
+++ b/tensorflow/core/common_runtime/graph_execution_state.cc
@@ -561,6 +561,10 @@
     grappler::GrapplerItem item;
     item.id = "tf_graph";
     graph_->ToGraphDef(&item.graph);
+    // TODO(b/114748242): Add a unit test to test this bug fix.
+    if (flib_def_) {
+      *item.graph.mutable_library() = flib_def_->ToProto();
+    }
 
     item.fetch.insert(item.fetch.end(),
                       options.callable_options.fetch().begin(),
diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.h b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
index df9c3a6..538a706 100644
--- a/tensorflow/core/common_runtime/mkl_cpu_allocator.h
+++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
@@ -23,12 +23,11 @@
 
 #include <cstdlib>
 #include "tensorflow/core/common_runtime/bfc_allocator.h"
-#include "tensorflow/core/common_runtime/visitable_allocator.h"
-#include "tensorflow/core/framework/allocator_registry.h"
+#include "tensorflow/core/common_runtime/pool_allocator.h"
 #include "tensorflow/core/lib/strings/numbers.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/platform/mem.h"
-#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/numa.h"
 
 #ifndef INTEL_MKL_DNN_ONLY
 #include "i_malloc.h"
@@ -40,20 +39,16 @@
 
 namespace tensorflow {
 
-class MklSubAllocator : public SubAllocator {
+class MklSubAllocator : public BasicCPUAllocator {
  public:
+  MklSubAllocator() : BasicCPUAllocator(port::kNUMANoAffinity, {}, {}) {}
   ~MklSubAllocator() override {}
-
-  void* Alloc(size_t alignment, size_t num_bytes) override {
-    return port::AlignedMalloc(num_bytes, alignment);
-  }
-  void Free(void* ptr, size_t num_bytes) override { port::AlignedFree(ptr); }
 };
 
 // CPU allocator that handles small-size allocations by calling
 // suballocator directly. Mostly, it is just a wrapper around a suballocator
 // (that calls malloc and free directly) with support for bookkeeping.
-class MklSmallSizeAllocator : public VisitableAllocator {
+class MklSmallSizeAllocator : public Allocator {
  public:
   MklSmallSizeAllocator(SubAllocator* sub_allocator, size_t total_memory,
                         const string& name)
@@ -75,10 +70,6 @@
       CHECK(map_.insert(map_val).second);
       // Increment statistics for small-size allocations.
       IncrementStats(num_bytes);
-      // Call alloc visitors.
-      for (const auto& visitor : alloc_visitors_) {
-        visitor(ptr, num_bytes);
-      }
     }
     return ptr;
   }
@@ -94,9 +85,6 @@
     if (map_iter != map_.end()) {
       // Call free visitors.
       size_t dealloc_bytes = map_iter->second;
-      for (const auto& visitor : free_visitors_) {
-        visitor(ptr, dealloc_bytes);
-      }
       sub_allocator_->Free(ptr, dealloc_bytes);
       DecrementStats(dealloc_bytes);
       map_.erase(map_iter);
@@ -121,16 +109,6 @@
     stats_.Clear();
   }
 
-  void AddAllocVisitor(Visitor visitor) override {
-    mutex_lock l(mutex_);
-    alloc_visitors_.push_back(visitor);
-  }
-
-  void AddFreeVisitor(Visitor visitor) override {
-    mutex_lock l(mutex_);
-    free_visitors_.push_back(visitor);
-  }
-
  private:
   // Increment statistics for the allocator handling small allocations.
   inline void IncrementStats(size_t alloc_size)
@@ -163,15 +141,11 @@
 
   // Allocator stats for small allocs
   AllocatorStats stats_ GUARDED_BY(mutex_);
-
-  // Visitors
-  std::vector<Visitor> alloc_visitors_ GUARDED_BY(mutex_);
-  std::vector<Visitor> free_visitors_ GUARDED_BY(mutex_);
 };
 
 /// CPU allocator for MKL that wraps BFC allocator and intercepts
 /// and redirects memory allocation calls from MKL.
-class MklCPUAllocator : public VisitableAllocator {
+class MklCPUAllocator : public Allocator {
  public:
   // Constructor and other standard functions
 
@@ -284,16 +258,6 @@
     large_size_allocator_->ClearStats();
   }
 
-  void AddAllocVisitor(Visitor visitor) override {
-    small_size_allocator_->AddAllocVisitor(visitor);
-    large_size_allocator_->AddAllocVisitor(visitor);
-  }
-
-  void AddFreeVisitor(Visitor visitor) override {
-    small_size_allocator_->AddFreeVisitor(visitor);
-    large_size_allocator_->AddFreeVisitor(visitor);
-  }
-
  private:
   // Hooks provided by this allocator for memory allocation routines from MKL
 
@@ -330,7 +294,7 @@
   // The alignment that we need for the allocations
   static constexpr const size_t kAlignment = 64;
 
-  VisitableAllocator* large_size_allocator_;     // owned by this class
+  Allocator* large_size_allocator_;              // owned by this class
   MklSmallSizeAllocator* small_size_allocator_;  // owned by this class.
 
   SubAllocator* sub_allocator_;  // not owned by this class
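
After this change MklSubAllocator no longer re-implements aligned allocation; it simply delegates to BasicCPUAllocator with no NUMA affinity and empty visitor lists. A minimal sketch of what that delegation provides (the sizes are illustrative):

```c++
// Sketch: MklSubAllocator (now a BasicCPUAllocator) provides aligned
// Alloc/Free plus the visitor plumbing from its base class.
MklSubAllocator sub;                               // kNUMANoAffinity, no visitors
void* p = sub.Alloc(/*alignment=*/64, /*num_bytes=*/1 << 20);  // AlignedMalloc
// ... hand p to the BFC or small-size allocator ...
sub.Free(p, 1 << 20);                              // AlignedFree
```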
diff --git a/tensorflow/core/common_runtime/parallel_concat_optimizer.cc b/tensorflow/core/common_runtime/parallel_concat_optimizer.cc
index f9f3644..6af4ca4 100644
--- a/tensorflow/core/common_runtime/parallel_concat_optimizer.cc
+++ b/tensorflow/core/common_runtime/parallel_concat_optimizer.cc
@@ -50,8 +50,8 @@
     }
     for (Node* n : matches) {
       AttrSlice n_attrs = n->attrs();
-      auto base_make_node = [n, g, &n_attrs](const string& op,
-                                             const string& name) {
+      auto base_make_node = [n, &n_attrs](const string& op,
+                                          const string& name) {
         NodeBuilder node_builder(name, op);
         node_builder.Device(n->requested_device());
         string colo;
@@ -60,7 +60,7 @@
         }
         return node_builder;
       };
-      auto make_node = [n, g, &n_attrs, &base_make_node](string op) {
+      auto make_node = [n, g, &base_make_node](string op) {
         return base_make_node(
             op, g->NewName(strings::StrCat(n->name(), "/Internal")));
       };
diff --git a/tensorflow/core/common_runtime/pool_allocator.cc b/tensorflow/core/common_runtime/pool_allocator.cc
index fdad8de..66dc8f3 100644
--- a/tensorflow/core/common_runtime/pool_allocator.cc
+++ b/tensorflow/core/common_runtime/pool_allocator.cc
@@ -40,8 +40,7 @@
       auto_resize_(auto_resize),
       pool_size_limit_(pool_size_limit),
       allocator_(allocator),
-      size_rounder_(size_rounder),
-      allocation_begun_(false) {
+      size_rounder_(size_rounder) {
   if (auto_resize) {
     CHECK_LT(size_t{0}, pool_size_limit)
         << "size limit must be > 0 if auto_resize is true.";
@@ -93,7 +92,6 @@
 }  // namespace
 
 void* PoolAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
-  if (!allocation_begun_) allocation_begun_ = true;
   if (num_bytes == 0) return nullptr;
 
   // If alignment is larger than kPoolAlignment, increase num_bytes so that we
@@ -129,9 +127,6 @@
     return PrepareChunk(r, alignment, num_bytes);
   } else {
     void* ptr = allocator_->Alloc(kPoolAlignment, num_bytes);
-    for (const auto& v : alloc_visitors_) {
-      v(ptr, num_bytes);
-    }
     return PrepareChunk(ptr, alignment, num_bytes);
   }
 }
@@ -141,9 +136,6 @@
   ChunkPrefix* cp = FindPrefix(ptr);
   CHECK_LE((void*)cp, (void*)ptr);
   if (!has_size_limit_ && !auto_resize_) {
-    for (const auto& v : free_visitors_) {
-      v(cp, cp->num_bytes);
-    }
     allocator_->Free(cp, cp->num_bytes);
   } else {
     mutex_lock lock(mutex_);
@@ -164,9 +156,6 @@
     mutex_lock lock(mutex_);
     for (auto iter : pool_) {
       PtrRecord* pr = iter.second;
-      for (const auto& v : free_visitors_) {
-        v(pr->ptr, pr->num_bytes);
-      }
       allocator_->Free(pr->ptr, pr->num_bytes);
       delete pr;
     }
@@ -221,9 +210,6 @@
     DCHECK(iter != pool_.end());
   }
   pool_.erase(iter);
-  for (const auto& v : free_visitors_) {
-    v(prec->ptr, prec->num_bytes);
-  }
   allocator_->Free(prec->ptr, prec->num_bytes);
   delete prec;
   ++evicted_count_;
@@ -269,28 +255,19 @@
   }
 }
 
-void PoolAllocator::AddAllocVisitor(Visitor visitor) {
-  mutex_lock lock(mutex_);
-  CHECK(!allocation_begun_)
-      << "AddAllocVisitor may not be called after pool allocation "
-      << "has begun.";
-  alloc_visitors_.push_back(visitor);
-}
-
-void PoolAllocator::AddFreeVisitor(Visitor visitor) {
-  mutex_lock lock(mutex_);
-  CHECK(!allocation_begun_)
-      << "AddFreeVisitor may not be called after pool allocation "
-      << "has begun.";
-  free_visitors_.push_back(visitor);
-}
-
 void* BasicCPUAllocator::Alloc(size_t alignment, size_t num_bytes) {
-  return port::AlignedMalloc(num_bytes, static_cast<int>(alignment));
+  void* ptr = nullptr;
+  if (num_bytes > 0) {
+    ptr = port::AlignedMalloc(num_bytes, static_cast<int>(alignment));
+    VisitAlloc(ptr, numa_node_, num_bytes);
+  }
+  return ptr;
 }
 
 void BasicCPUAllocator::Free(void* ptr, size_t num_bytes) {
-  port::AlignedFree(ptr);
+  if (num_bytes > 0) {
+    VisitFree(ptr, numa_node_, num_bytes);
+    port::AlignedFree(ptr);
+  }
 }
-
 }  // namespace tensorflow
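
The net effect of the PoolAllocator/BasicCPUAllocator changes is that visit callbacks now fire inside the SubAllocator itself, only for real (non-pooled) allocations and only when num_bytes > 0. A hedged sketch, using the SubAllocator::Visitor style shown in the tests above:

```c++
int64 total_allocated = 0;
SubAllocator::Visitor count_bytes =
    [&total_allocated](void* ptr, int numa_node, int64 size) {
      total_allocated += size;  // includes the 16-byte ChunkPrefix padding
    };
BasicCPUAllocator* sub =
    new BasicCPUAllocator(/*numa_node=*/0, {count_bytes}, /*free_visitors=*/{});
PoolAllocator pool(/*pool_size_limit=*/2, /*auto_resize=*/false, sub,
                   new NoopRounder, "cpu_pool");

void* p = pool.AllocateRaw(/*alignment=*/4, /*num_bytes=*/256);  // visitor fires
pool.DeallocateRaw(p);               // chunk returns to the pool; no free visit
void* q = pool.AllocateRaw(4, 256);  // served from the pool; no alloc visit
pool.DeallocateRaw(q);
```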
diff --git a/tensorflow/core/common_runtime/pool_allocator.h b/tensorflow/core/common_runtime/pool_allocator.h
index 6077344..5b4623b 100644
--- a/tensorflow/core/common_runtime/pool_allocator.h
+++ b/tensorflow/core/common_runtime/pool_allocator.h
@@ -16,14 +16,13 @@
 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_POOL_ALLOCATOR_H_
 #define TENSORFLOW_CORE_COMMON_RUNTIME_POOL_ALLOCATOR_H_
 
-// Simple LRU pool allocators for various flavors of CPU RAM that
-// implement the VisitableAllocator interface.
+// Simple LRU pool allocators for various flavors of CPU RAM.
 
 #include <atomic>
 #include <map>
 #include <memory>
 #include <vector>
-#include "tensorflow/core/common_runtime/visitable_allocator.h"
+#include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/lib/core/bits.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/macros.h"
@@ -41,7 +40,7 @@
 
 // Size-limited pool of memory buffers obtained from a SubAllocator
 // instance.  Pool eviction policy is LRU.
-class PoolAllocator : public VisitableAllocator {
+class PoolAllocator : public Allocator {
  public:
   // "pool_size_limit" is the maximum number of returned, re-usable
   // memory buffers to keep in the pool.  If pool_size_limit == 0, the
@@ -64,14 +63,6 @@
 
   void DeallocateRaw(void* ptr) override;
 
-  // REQUIRES: The following functions may only be called prior
-  // to the first Allocate*() call.  Once allocation has begun, it is
-  // illegal to register another visitor.
-
-  void AddAllocVisitor(Visitor visitor) override;
-
-  void AddFreeVisitor(Visitor visitor) override;
-
   // Allocate an unused memory region of size "num_bytes".  Fetch from
   // the pool if available, otherwise call allocator_.
   void* Get(size_t num_bytes);
@@ -141,12 +132,6 @@
   int64 put_count_ GUARDED_BY(mutex_) = 0;
   int64 allocated_count_ GUARDED_BY(mutex_) = 0;
   int64 evicted_count_ GUARDED_BY(mutex_) = 0;
-  // Write access to these is guarded by mutex_, but not read
-  // access. They may only be modified prior to the first
-  // allocation.  Later attempts to modify will fail.
-  std::vector<Visitor> alloc_visitors_;
-  std::vector<Visitor> free_visitors_;
-  std::atomic<bool> allocation_begun_;
 };
 
 // Do-nothing rounder. Passes through sizes unchanged.
@@ -166,7 +151,9 @@
 class BasicCPUAllocator : public SubAllocator {
  public:
   // Argument numa_node is currently ignored.
-  explicit BasicCPUAllocator(int numa_node) : numa_node_(numa_node) {}
+  BasicCPUAllocator(int numa_node, const std::vector<Visitor>& alloc_visitors,
+                    const std::vector<Visitor>& free_visitors)
+      : SubAllocator(alloc_visitors, free_visitors), numa_node_(numa_node) {}
 
   ~BasicCPUAllocator() override {}
 
@@ -176,6 +163,8 @@
 
  private:
   int numa_node_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(BasicCPUAllocator);
 };
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/process_state.cc b/tensorflow/core/common_runtime/process_state.cc
index 447338e..bcaa37f 100644
--- a/tensorflow/core/common_runtime/process_state.cc
+++ b/tensorflow/core/common_runtime/process_state.cc
@@ -71,20 +71,28 @@
   return MemDesc();
 }
 
-VisitableAllocator* ProcessState::GetCPUAllocator(int numa_node) {
+Allocator* ProcessState::GetCPUAllocator(int numa_node) {
   CHECK_GE(numa_node, 0);
   if (!numa_enabled_) numa_node = 0;
   mutex_lock lock(mu_);
   while (cpu_allocators_.size() <= static_cast<size_t>(numa_node)) {
+    // If visitors have been defined, we need an Allocator built from
+    // a SubAllocator.  Prefer BFCAllocator, but fall back to PoolAllocator
+    // depending on the env var setting.
+    const bool alloc_visitors_defined =
+        (!cpu_alloc_visitors_.empty() || !cpu_free_visitors_.empty());
     bool use_bfc_allocator = false;
-    // TODO(reedwm): Switch default to BGFAllocator if it's at least as fast and
-    // efficient.
-    Status status = ReadBoolFromEnvVar("TF_CPU_ALLOCATOR_USE_BFC", false,
-                                       &use_bfc_allocator);
+    Status status = ReadBoolFromEnvVar(
+        "TF_CPU_ALLOCATOR_USE_BFC", alloc_visitors_defined, &use_bfc_allocator);
     if (!status.ok()) {
       LOG(ERROR) << "GetCPUAllocator: " << status.error_message();
     }
-    VisitableAllocator* allocator;
+    Allocator* allocator = nullptr;
+    SubAllocator* sub_allocator =
+        (alloc_visitors_defined || use_bfc_allocator)
+            ? new BasicCPUAllocator(numa_enabled_ ? numa_node : -1,
+                                    cpu_alloc_visitors_, cpu_free_visitors_)
+            : nullptr;
     if (use_bfc_allocator) {
       // TODO(reedwm): evaluate whether 64GB by default is the best choice.
       int64 cpu_mem_limit_in_mb = -1;
@@ -95,34 +103,63 @@
         LOG(ERROR) << "GetCPUAllocator: " << status.error_message();
       }
       int64 cpu_mem_limit = cpu_mem_limit_in_mb * (1LL << 20);
-      allocator = new BFCAllocator(
-          new BasicCPUAllocator(numa_enabled_ ? numa_node : -1), cpu_mem_limit,
-          true /*allow_growth*/, "bfc_cpu_allocator_for_gpu" /*name*/);
+      DCHECK(sub_allocator);
+      allocator =
+          new BFCAllocator(sub_allocator, cpu_mem_limit, true /*allow_growth*/,
+                           "bfc_cpu_allocator_for_gpu" /*name*/);
       VLOG(2) << "Using BFCAllocator with memory limit of "
               << cpu_mem_limit_in_mb << " MB for ProcessState CPU allocator";
-    } else {
-      allocator = new PoolAllocator(
-          100 /*pool_size_limit*/, true /*auto_resize*/,
-          new BasicCPUAllocator(numa_enabled_ ? numa_node : -1),
-          new NoopRounder, "cpu_pool");
+    } else if (alloc_visitors_defined) {
+      DCHECK(sub_allocator);
+      allocator =
+          new PoolAllocator(100 /*pool_size_limit*/, true /*auto_resize*/,
+                            sub_allocator, new NoopRounder, "cpu_pool");
       VLOG(2) << "Using PoolAllocator for ProcessState CPU allocator "
               << "numa_enabled_=" << numa_enabled_
               << " numa_node=" << numa_node;
+    } else {
+      DCHECK(!sub_allocator);
+      allocator = cpu_allocator();
     }
-    if (LogMemory::IsEnabled()) {
+    if (LogMemory::IsEnabled() && !allocator->TracksAllocationSizes()) {
       // Wrap the allocator to track allocation ids for better logging
       // at the cost of performance.
-      allocator = new TrackingVisitableAllocator(allocator, true);
+      allocator = new TrackingAllocator(allocator, true);
     }
     cpu_allocators_.push_back(allocator);
+    if (!sub_allocator) {
+      DCHECK(cpu_alloc_visitors_.empty() && cpu_free_visitors_.empty());
+    }
   }
   return cpu_allocators_[numa_node];
 }
 
+void ProcessState::AddCPUAllocVisitor(SubAllocator::Visitor visitor) {
+  VLOG(1) << "AddCPUAllocVisitor";
+  mutex_lock lock(mu_);
+  CHECK_EQ(0, cpu_allocators_.size())  // Crash OK
+      << "AddCPUAllocVisitor must be called prior to first call to "
+         "ProcessState::GetCPUAllocator";
+  cpu_alloc_visitors_.push_back(std::move(visitor));
+}
+
+void ProcessState::AddCPUFreeVisitor(SubAllocator::Visitor visitor) {
+  mutex_lock lock(mu_);
+  CHECK_EQ(0, cpu_allocators_.size())  // Crash OK
+      << "AddCPUFreeVisitor must be called prior to first call to "
+         "ProcessState::GetCPUAllocator";
+  cpu_free_visitors_.push_back(std::move(visitor));
+}
+
 void ProcessState::TestOnlyReset() {
   mutex_lock lock(mu_);
+  // Don't delete this value because it's static.
+  Allocator* default_cpu_allocator = cpu_allocator();
   mem_desc_map_.clear();
-  gtl::STLDeleteElements(&cpu_allocators_);
+  for (Allocator* a : cpu_allocators_) {
+    if (a != default_cpu_allocator) delete a;
+  }
+  cpu_allocators_.clear();
   gtl::STLDeleteElements(&cpu_al_);
 }
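
The CPU path mirrors the GPU one: visitors must be added before the first GetCPUAllocator call, and merely registering one changes the default allocator from the plain cpu_allocator() to a BFCAllocator built over a BasicCPUAllocator. A sketch, assuming the usual ProcessState::singleton() accessor; the logging bodies are purely illustrative:

```c++
ProcessState* ps = ProcessState::singleton();
ps->AddCPUAllocVisitor([](void* ptr, int numa_node, size_t num_bytes) {
  VLOG(1) << "CPU alloc " << num_bytes << " bytes on node " << numa_node;
});
ps->AddCPUFreeVisitor([](void* ptr, int numa_node, size_t num_bytes) {
  VLOG(1) << "CPU free " << num_bytes << " bytes on node " << numa_node;
});
// First use builds a BFCAllocator because visitors are defined; setting
// TF_CPU_ALLOCATOR_USE_BFC=false instead yields a PoolAllocator over the
// same visitor-aware BasicCPUAllocator.
Allocator* a = ps->GetCPUAllocator(/*numa_node=*/0);
```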
 
diff --git a/tensorflow/core/common_runtime/process_state.h b/tensorflow/core/common_runtime/process_state.h
index 2892677..cac312d 100644
--- a/tensorflow/core/common_runtime/process_state.h
+++ b/tensorflow/core/common_runtime/process_state.h
@@ -30,7 +30,6 @@
 namespace tensorflow {
 
 class Allocator;
-class VisitableAllocator;
 class PoolAllocator;
 
 // Singleton that manages per-process state, e.g. allocation of
@@ -65,7 +64,15 @@
 
   // Returns the one CPUAllocator used for the given numa_node.
   // TEMPORARY: ignores numa_node.
-  VisitableAllocator* GetCPUAllocator(int numa_node);
+  Allocator* GetCPUAllocator(int numa_node);
+
+  // Registers alloc visitor for the CPU allocator(s).
+  // REQUIRES: must be called before GetCPUAllocator.
+  void AddCPUAllocVisitor(SubAllocator::Visitor v);
+
+  // Registers free visitor for the CPU allocator(s).
+  // REQUIRES: must be called before GetCPUAllocator.
+  void AddCPUFreeVisitor(SubAllocator::Visitor v);
 
   typedef std::unordered_map<const void*, MemDesc> MDMap;
 
@@ -87,7 +94,9 @@
 
   mutex mu_;
 
-  std::vector<VisitableAllocator*> cpu_allocators_ GUARDED_BY(mu_);
+  std::vector<Allocator*> cpu_allocators_ GUARDED_BY(mu_);
+  std::vector<SubAllocator::Visitor> cpu_alloc_visitors_ GUARDED_BY(mu_);
+  std::vector<SubAllocator::Visitor> cpu_free_visitors_ GUARDED_BY(mu_);
 
   virtual ~ProcessState();
 
diff --git a/tensorflow/core/common_runtime/renamed_device.h b/tensorflow/core/common_runtime/renamed_device.h
index 103eee0..c00789a 100644
--- a/tensorflow/core/common_runtime/renamed_device.h
+++ b/tensorflow/core/common_runtime/renamed_device.h
@@ -58,6 +58,15 @@
     return underlying_->GetAllocator(attr);
   }
 
+  Allocator* GetScopedAllocator(AllocatorAttributes attr,
+                                int64 step_id) override {
+    return underlying_->GetScopedAllocator(attr, step_id);
+  }
+
+  ScopedAllocatorMgr* GetScopedAllocatorMgr() const override {
+    return underlying_->GetScopedAllocatorMgr();
+  }
+
   const Eigen::ThreadPoolDevice* eigen_cpu_device() override {
     return underlying_->eigen_cpu_device();
   }
@@ -72,9 +81,10 @@
     return underlying_->MakeGpuDevice();
   }
 
-  void ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device,
-                             DeviceContext* dc, Allocator* allocator) override {
-    underlying_->ReinitializeGpuDevice(context, device, dc, allocator);
+  Status ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device,
+                               DeviceContext* dc,
+                               Allocator* allocator) override {
+    return underlying_->ReinitializeGpuDevice(context, device, dc, allocator);
   }
 
   Status MakeTensorFromProto(const TensorProto& tensor_proto,
diff --git a/tensorflow/core/common_runtime/session_ref.cc b/tensorflow/core/common_runtime/session_ref.cc
deleted file mode 100644
index b931ef4..0000000
--- a/tensorflow/core/common_runtime/session_ref.cc
+++ /dev/null
@@ -1,170 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/core/common_runtime/session_ref.h"
-
-#include <utility>
-
-namespace tensorflow {
-
-namespace {
-
-// Scope helper to track active calls and manage session lifetime.
-struct RunCounter {
-  std::shared_ptr<Session> session;
-  uint64* value;
-  mutex* m;
-  condition_variable* cv;
-
-  explicit RunCounter(std::shared_ptr<Session> s, uint64* v, mutex* m,
-                      condition_variable* cv)
-      : session(std::move(s)), value(v), m(m), cv(cv) {
-    mutex_lock l(*m);
-    ++*value;
-  }
-
-  ~RunCounter() {
-    mutex_lock l(*m);
-    if (--*value == 0) {
-      cv->notify_all();
-    }
-  }
-};
-
-}  // namespace
-
-Status SessionRef::CheckNotClosed() {
-  mutex_lock l(run_lock_);
-  if (session_ == nullptr) return errors::Cancelled("Session has been closed.");
-  return ::tensorflow::Status::OK();
-}
-
-Status SessionRef::Run(const RunOptions& run_options,
-                       const std::vector<std::pair<string, Tensor> >& inputs,
-                       const std::vector<string>& output_tensor_names,
-                       const std::vector<string>& target_node_names,
-                       std::vector<Tensor>* outputs,
-                       RunMetadata* run_metadata) {
-  TF_RETURN_IF_ERROR(CheckNotClosed());
-  RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
-  return rc.session->Run(run_options, inputs, output_tensor_names,
-                         target_node_names, outputs, run_metadata);
-}
-
-Status SessionRef::Create(const GraphDef& graph) {
-  TF_RETURN_IF_ERROR(CheckNotClosed());
-  RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
-  return rc.session->Create(graph);
-}
-
-Status SessionRef::Create(const RunOptions& run_options,
-                          const GraphDef& graph) {
-  TF_RETURN_IF_ERROR(CheckNotClosed());
-  RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
-  return rc.session->Create(run_options, graph);
-}
-
-Status SessionRef::Extend(const RunOptions& run_options,
-                          const GraphDef& graph) {
-  TF_RETURN_IF_ERROR(CheckNotClosed());
-  RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
-  return rc.session->Extend(run_options, graph);
-}
-
-Status SessionRef::Extend(const GraphDef& graph) {
-  TF_RETURN_IF_ERROR(CheckNotClosed());
-  RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
-  return rc.session->Extend(graph);
-}
-
-Status SessionRef::Close(const RunOptions& run_options) {
-  TF_RETURN_IF_ERROR(CheckNotClosed());
-  mutex_lock l(run_lock_);
-  Status status = session_->Close(run_options);
-  session_.reset();
-  while (run_count_ > 0) {
-    run_finished_.wait(l);
-  }
-  return status;
-}
-
-Status SessionRef::Close() {
-  TF_RETURN_IF_ERROR(CheckNotClosed());
-  mutex_lock l(run_lock_);
-  Status status = session_->Close();
-  session_.reset();
-  while (run_count_ > 0) {
-    run_finished_.wait(l);
-  }
-  return status;
-}
-
-Status SessionRef::Run(const std::vector<std::pair<string, Tensor> >& inputs,
-                       const std::vector<string>& output_tensor_names,
-                       const std::vector<string>& target_node_names,
-                       std::vector<Tensor>* outputs) {
-  TF_RETURN_IF_ERROR(CheckNotClosed());
-  RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
-  return rc.session->Run(inputs, output_tensor_names, target_node_names,
-                         outputs);
-}
-
-Status SessionRef::ListDevices(std::vector<DeviceAttributes>* response) {
-  TF_RETURN_IF_ERROR(CheckNotClosed());
-  RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
-  return rc.session->ListDevices(response);
-}
-
-Status SessionRef::PRunSetup(const std::vector<string>& input_names,
-                             const std::vector<string>& output_names,
-                             const std::vector<string>& target_nodes,
-                             string* handle) {
-  TF_RETURN_IF_ERROR(CheckNotClosed());
-  RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
-  return rc.session->PRunSetup(input_names, output_names, target_nodes, handle);
-}
-
-Status SessionRef::PRun(const string& handle,
-                        const std::vector<std::pair<string, Tensor> >& inputs,
-                        const std::vector<string>& output_names,
-                        std::vector<Tensor>* outputs) {
-  TF_RETURN_IF_ERROR(CheckNotClosed());
-  RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
-  return rc.session->PRun(handle, inputs, output_names, outputs);
-}
-
-Status SessionRef::MakeCallable(const CallableOptions& callable_options,
-                                CallableHandle* out_handle) {
-  TF_RETURN_IF_ERROR(CheckNotClosed());
-  RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
-  return rc.session->MakeCallable(callable_options, out_handle);
-}
-
-Status SessionRef::RunCallable(CallableHandle handle,
-                               const std::vector<Tensor>& feed_tensors,
-                               std::vector<Tensor>* fetch_tensors,
-                               RunMetadata* run_metadata) {
-  TF_RETURN_IF_ERROR(CheckNotClosed());
-  RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
-  return rc.session->RunCallable(handle, feed_tensors, fetch_tensors,
-                                 run_metadata);
-}
-
-Status SessionRef::ReleaseCallable(CallableHandle handle) {
-  TF_RETURN_IF_ERROR(CheckNotClosed());
-  RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
-  return rc.session->ReleaseCallable(handle);
-}
-
-}  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/step_stats_collector.cc b/tensorflow/core/common_runtime/step_stats_collector.cc
index 836cb8e..a70ab93 100644
--- a/tensorflow/core/common_runtime/step_stats_collector.cc
+++ b/tensorflow/core/common_runtime/step_stats_collector.cc
@@ -27,6 +27,7 @@
 #include "tensorflow/core/lib/strings/scanner.h"
 #include "tensorflow/core/lib/strings/stringprintf.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/ptr_util.h"
 
 namespace tensorflow {
 namespace {
@@ -40,46 +41,24 @@
 };
 }  // namespace
 
-NodeExecStatsWrapper::NodeExecStatsWrapper(const string& node_name)
-    : NodeExecStatsWrapper(new NodeExecStats) {
-  stats_->set_node_name(node_name);
-}
-NodeExecStatsWrapper::NodeExecStatsWrapper(NodeExecStats* stats)
-    : stats_(stats) {}
-
-void NodeExecStatsWrapper::SetOutput(int slot, const Tensor* v) {
-  DCHECK(v);
-  NodeOutput* no = stats_->add_output();
-  no->set_slot(slot);
-  v->FillDescription(no->mutable_tensor_description());
+NodeExecStatsWrapper::NodeExecStatsWrapper(
+    const Node* node, StepStatsCollector* step_stats_collector)
+    : NodeExecStatsWrapper(MakeUnique<NodeExecStats>(), node,
+                           step_stats_collector) {
+  stats_->set_node_name(node->name());
 }
 
-void NodeExecStatsWrapper::SetMemory(OpKernelContext* ctx) {
-  for (const auto& allocator_pair : ctx->wrapped_allocators()) {
-    AddAllocation(allocator_pair.first, allocator_pair.second);
-  }
-  auto* ms = stats_->mutable_memory_stats();
-  ms->set_temp_memory_size(ctx->temp_memory_allocated());
-  for (const auto& alloc_id : ctx->persistent_alloc_ids()) {
-    ms->mutable_persistent_tensor_alloc_ids()->Add(alloc_id);
-  }
-  ms->set_persistent_memory_size(ctx->persistent_memory_allocated());
-}
+NodeExecStatsWrapper::NodeExecStatsWrapper(
+    std::unique_ptr<NodeExecStats> stats, const Node* node,
+    StepStatsCollector* step_stats_collector)
+    : stats_(std::move(stats)),
+      node_(node),
+      step_stats_collector_(step_stats_collector) {}
 
-void NodeExecStatsWrapper::SetReferencedTensors(
-    const TensorReferenceVector& tensors) {
-  // be careful not to increment the reference count on any tensor
-  // while recording the information
-  for (size_t i = 0; i < tensors.size(); ++i) {
-    AllocationDescription* description = stats_->add_referenced_tensor();
-    tensors.at(i).FillDescription(description);
-  }
-}
-
-// TODO(tucker): merge with the DetailText function in session.cc
-// in a common location.
-bool NodeExecStatsWrapper::SetTimelineLabel(const Node* node) {
-  bool is_transfer_node = false;
+void NodeExecStatsWrapper::Done(const string& device) {
+  // TODO(tucker): merge with the DetailText function in session.cc in a common
+  // location.
+  DCHECK(node_);
   string memory;
   for (auto& all : stats_->memory()) {
     int64 tot = all.total_bytes();
@@ -96,31 +75,96 @@
       }
     }
   }
-  const AttrSlice attrs = node->attrs();
+  const AttrSlice attrs = node_->attrs();
   string text;
-  if (IsSend(node)) {
+  if (IsSend(node_)) {
     string tensor_name;
     TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name));
     string recv_device;
     TF_CHECK_OK(GetNodeAttr(attrs, "recv_device", &recv_device));
-    text = strings::StrCat(memory, node->name(), " = ", node->type_string(),
+    text = strings::StrCat(memory, node_->name(), " = ", node_->type_string(),
                            "(", tensor_name, " @", recv_device);
-    is_transfer_node = true;
-  } else if (IsRecv(node)) {
+  } else if (IsRecv(node_)) {
     string tensor_name;
     TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name));
     string send_device;
     TF_CHECK_OK(GetNodeAttr(attrs, "send_device", &send_device));
-    text = strings::StrCat(memory, node->name(), " = ", node->type_string(),
+    text = strings::StrCat(memory, node_->name(), " = ", node_->type_string(),
                            "(", tensor_name, " @", send_device);
-    is_transfer_node = true;
   } else {
     text =
-        strings::StrCat(memory, node->name(), " = ", node->type_string(), "(",
-                        str_util::Join(node->requested_inputs(), ", "), ")");
+        strings::StrCat(memory, node_->name(), " = ", node_->type_string(), "(",
+                        str_util::Join(node_->requested_inputs(), ", "), ")");
   }
   stats_->set_timeline_label(text);
-  return is_transfer_node;
+  step_stats_collector_->Save(device, this);
+}
+
+void NodeExecStatsWrapper::RecordExecutorStarted() {
+  int64 now_nanos = Env::Default()->NowNanos();
+  stats_->set_all_start_micros(now_nanos / EnvTime::kMicrosToNanos);
+  stats_->set_all_start_nanos(now_nanos);
+}
+
+void NodeExecStatsWrapper::RecordComputeStarted() {
+  int64 now_nanos = Env::Default()->NowNanos();
+  DCHECK_NE(stats_->all_start_micros(), 0);
+  DCHECK_NE(stats_->all_start_nanos(), 0);
+  stats_->set_op_start_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
+                                  stats_->all_start_micros());
+  stats_->set_op_start_rel_nanos(now_nanos - stats_->all_start_nanos());
+}
+
+void NodeExecStatsWrapper::RecordComputeEnded() {
+  int64 now_nanos = Env::Default()->NowNanos();
+  DCHECK_NE(stats_->all_start_micros(), 0);
+  DCHECK_NE(stats_->all_start_nanos(), 0);
+  stats_->set_op_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
+                                stats_->all_start_micros());
+  stats_->set_op_end_rel_nanos(now_nanos - stats_->all_start_nanos());
+}
+
+void NodeExecStatsWrapper::RecordExecutorEnded() {
+  int64 now_nanos = Env::Default()->NowNanos();
+  DCHECK_NE(stats_->all_start_micros(), 0);
+  DCHECK_NE(stats_->all_start_nanos(), 0);
+  stats_->set_all_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
+                                 stats_->all_start_micros());
+  stats_->set_all_end_rel_nanos(now_nanos - stats_->all_start_nanos());
+}
+
+void NodeExecStatsWrapper::SetScheduled(int64 nanos) {
+  stats_->set_scheduled_micros(nanos / EnvTime::kMicrosToNanos);
+  stats_->set_scheduled_nanos(nanos);
+}
+
+void NodeExecStatsWrapper::SetMemory(OpKernelContext* ctx) {
+  for (const auto& allocator_pair : ctx->wrapped_allocators()) {
+    AddAllocation(allocator_pair.first, allocator_pair.second);
+  }
+  auto* ms = stats_->mutable_memory_stats();
+  ms->set_temp_memory_size(ctx->temp_memory_allocated());
+  for (const auto& alloc_id : ctx->persistent_alloc_ids()) {
+    ms->mutable_persistent_tensor_alloc_ids()->Add(alloc_id);
+  }
+  ms->set_persistent_memory_size(ctx->persistent_memory_allocated());
+}
+
+void NodeExecStatsWrapper::SetOutput(int slot, const Tensor* tensor) {
+  DCHECK(tensor);
+  NodeOutput* node_output = stats_->add_output();
+  node_output->set_slot(slot);
+  tensor->FillDescription(node_output->mutable_tensor_description());
+}
+
+void NodeExecStatsWrapper::SetReferencedTensors(
+    const TensorReferenceVector& tensors) {
+  // be careful not to increment the reference count on any tensor
+  // while recording the information
+  for (size_t i = 0; i < tensors.size(); ++i) {
+    AllocationDescription* description = stats_->add_referenced_tensor();
+    tensors.at(i).FillDescription(description);
+  }
 }
 
 void NodeExecStatsWrapper::AddAllocation(
@@ -150,8 +194,8 @@
   allocations_.clear();
 }
 
-StepStatsCollector::StepStatsCollector(StepStats* ss)
-    : finalized_(false), step_stats_(ss) {}
+StepStatsCollector::StepStatsCollector(StepStats* step_stats)
+    : finalized_(false), step_stats_(step_stats) {}
 
 static int ExtractGpuWithStreamAll(string device_name) {
   // Check if the device name matches the ".*gpu:(\\d+)/stream:all$" regexp,
@@ -338,30 +382,42 @@
   }
 }
 
-void StepStatsCollector::Save(const string& device, NodeExecStats* nt) {
-  Save(device, new NodeExecStatsWrapper(nt));
+void StepStatsCollector::Save(const string& device,
+                              NodeExecStats* node_stats_pb) {
+  Save(device,
+       new NodeExecStatsWrapper(std::unique_ptr<NodeExecStats>(node_stats_pb),
+                                nullptr, this));
 }
 
 void StepStatsCollector::Save(const string& device,
-                              NodeExecStatsWrapper* stats) {
-  if (!stats) return;
-  VLOG(1) << "Save dev " << device << " nt " << stats->stats();
+                              NodeExecStatsWrapper* node_stats) {
+  if (!node_stats) return;
+  VLOG(1) << "Save dev " << device << " node stats " << node_stats->stats();
   {
     mutex_lock l(mu_);
     if (finalized_) {
       LOG(WARNING) << "stats saved after finalize will not be collected.";
     }
-    if (!step_stats_ || collectedNodes >= kMaxCollectedNodes) {
+    if (!step_stats_ || collected_nodes_ >= kMaxCollectedNodes) {
       VLOG(1) << "step_stats_ nullptr or already collected too many nodes.";
-      delete stats;
+      delete node_stats;
       return;
     }
-    auto& dss = dev_stats_[device];
-    dss.push_back(std::unique_ptr<NodeExecStatsWrapper>(stats));
-    collectedNodes++;
+    auto& device_stats = dev_stats_[device];
+    device_stats.push_back(std::unique_ptr<NodeExecStatsWrapper>(node_stats));
+    collected_nodes_++;
   }
 }
 
+NodeExecStatsInterface* StepStatsCollector::CreateNodeExecStats(
+    const Node* node) {
+  // Only collect statistics for non-transfer nodes.
+  if (IsSend(node) || IsRecv(node)) {
+    return nullptr;
+  }
+  return new NodeExecStatsWrapper(node, this);
+}
+
 string StepStatsCollector::ReportAllocsOnResourceExhausted(const string& err) {
   mutex_lock l(mu_);
   if (err.find("OOM") == err.npos) {
@@ -446,12 +502,12 @@
   FinalizeInternal();
 }
 
-void StepStatsCollector::FinalizeAndSwap(StepStats* ss) {
+void StepStatsCollector::FinalizeAndSwap(StepStats* step_stats) {
   mutex_lock l(mu_);
   CHECK(step_stats_);
   FinalizeInternal();
-  ss->Swap(step_stats_);
-  collectedNodes = 0;
+  step_stats->Swap(step_stats_);
+  collected_nodes_ = 0;
 }
 
 void StepStatsCollector::FinalizeInternal() {
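
The refactor above turns the per-node timing helpers into out-of-line members of NodeExecStatsWrapper behind the new NodeExecStatsInterface, but the call sequence the executor is expected to follow is unchanged. A sketch of one plausible lifecycle, assuming `step_stats_collector`, `node`, `ctx`, and `device_name` are in scope; CreateNodeExecStats returns nullptr for Send/Recv nodes:

```c++
NodeExecStatsInterface* stats =
    step_stats_collector->CreateNodeExecStats(node);
if (stats != nullptr) {
  stats->RecordExecutorStarted();  // node picked up by the executor
  stats->RecordComputeStarted();   // just before Compute()/ComputeAsync()
  // ... kernel runs ...
  stats->RecordComputeEnded();     // kernel (or its async callback) finished
  stats->SetMemory(ctx);           // allocator, temp and persistent memory stats
  stats->RecordExecutorEnded();    // executor done with this node
  stats->Done(device_name);        // hands the stats back to the collector;
                                   // do not touch `stats` after this call
}
```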
diff --git a/tensorflow/core/common_runtime/step_stats_collector.h b/tensorflow/core/common_runtime/step_stats_collector.h
index 7206fbf..4365b11 100644
--- a/tensorflow/core/common_runtime/step_stats_collector.h
+++ b/tensorflow/core/common_runtime/step_stats_collector.h
@@ -36,81 +36,78 @@
 class NodeExecStats;
 class OpKernelContext;
 class StepStats;
+class StepStatsCollector;
 class Tensor;
 class TrackingAllocator;
 
-// Wraps NodeExecStats and adds allocation to it.
-class NodeExecStatsWrapper {
+// Statistics collection interface for individual node execution.
+//
+// See `NodeExecStatsWrapper` for a concrete implementation of this interface
+// that interfaces with the `Session` layer.
+class NodeExecStatsInterface {
  public:
-  NodeExecStatsWrapper(const string& node_name);
-  // Owns 'stats'.
-  NodeExecStatsWrapper(NodeExecStats* stats);
+  virtual ~NodeExecStatsInterface() {}
 
-  // Destructor calls Finalize() to release the TrackingAllocators.
-  ~NodeExecStatsWrapper() { Finalize(); }
-
-  // Records the absolute time in nanoseconds at which this node became
-  // runnable (i.e. was scheduled for execution).
-  void SetScheduled(int64 nanos) {
-    stats_->set_scheduled_micros(nanos / EnvTime::kMicrosToNanos);
-    stats_->set_scheduled_nanos(nanos);
-  }
+  // Called when the statistics collection for the node has finished. Once this
+  // method is called, the caller should not make assumptions about the validity
+  // of this object.
+  virtual void Done(const string& device) = 0;
 
   // Called immediately after this node starts being processed by the executor.
-  void RecordExecutorStarted() {
-    int64 now_nanos = Env::Default()->NowNanos();
-    stats_->set_all_start_micros(now_nanos / EnvTime::kMicrosToNanos);
-    stats_->set_all_start_nanos(now_nanos);
-  }
+  virtual void RecordExecutorStarted() = 0;
 
   // Called immediately before this node's `Compute()` or `ComputeAsync()`
   // method is called.
-  void RecordComputeStarted() {
-    int64 now_nanos = Env::Default()->NowNanos();
-    DCHECK_NE(stats_->all_start_micros(), 0);
-    DCHECK_NE(stats_->all_start_nanos(), 0);
-    stats_->set_op_start_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
-                                    stats_->all_start_micros());
-    stats_->set_op_start_rel_nanos(now_nanos - stats_->all_start_nanos());
-  }
+  virtual void RecordComputeStarted() = 0;
 
   // Called immediately after this node's `Compute()` method returned (or, for
   // asynchronous operations, the callback passed to its `ComputeAsync()` method
   // was called).
-  void RecordComputeEnded() {
-    int64 now_nanos = Env::Default()->NowNanos();
-    DCHECK_NE(stats_->all_start_micros(), 0);
-    DCHECK_NE(stats_->all_start_nanos(), 0);
-    stats_->set_op_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
-                                  stats_->all_start_micros());
-    stats_->set_op_end_rel_nanos(now_nanos - stats_->all_start_nanos());
-  }
+  virtual void RecordComputeEnded() = 0;
 
   // Called immediately after this executor finishes processing this node.
-  void RecordExecutorEnded() {
-    int64 now_nanos = Env::Default()->NowNanos();
-    DCHECK_NE(stats_->all_start_micros(), 0);
-    DCHECK_NE(stats_->all_start_nanos(), 0);
-    stats_->set_all_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
-                                   stats_->all_start_micros());
-    stats_->set_all_end_rel_nanos(now_nanos - stats_->all_start_nanos());
-  }
-
-  // Records information about the tensor produced by this node at the given
-  // output slot.
-  void SetOutput(int slot, const Tensor* v);
+  virtual void RecordExecutorEnded() = 0;
 
   // Records information about the memory allocated during the execution of this
   // node.
-  void SetMemory(OpKernelContext* ctx);
+  virtual void SetMemory(OpKernelContext* ctx) = 0;
+
+  // Records information about the tensor produced by this node at the given
+  // output slot.
+  virtual void SetOutput(int slot, const Tensor* tensor) = 0;
 
   // Records information about the tensors that were accessed during the
   // execution of this node.
-  void SetReferencedTensors(const TensorReferenceVector& tensors);
+  virtual void SetReferencedTensors(const TensorReferenceVector& tensors) = 0;
 
-  // Sets the timeline_label field of the wrapped NodeExecStats, using data
-  // from *node. Returns true iff the node is a transfer node.
-  bool SetTimelineLabel(const Node* node);
+  // Records the absolute time in nanoseconds at which this node became
+  // runnable (i.e. was scheduled for execution).
+  virtual void SetScheduled(int64 nanos) = 0;
+};
+
+// Wraps NodeExecStats and adds allocation to it.
+class NodeExecStatsWrapper : public NodeExecStatsInterface {
+ public:
+  // Does not take ownership of `node` or `step_stats_collector`.
+  NodeExecStatsWrapper(const Node* node,
+                       StepStatsCollector* step_stats_collector);
+
+  // Takes ownership of `stats` but not `node` or `step_stats_collector`.
+  NodeExecStatsWrapper(std::unique_ptr<NodeExecStats> stats, const Node* node,
+                       StepStatsCollector* step_stats_collector);
+
+  // Destructor calls Finalize() to release the TrackingAllocators.
+  ~NodeExecStatsWrapper() { Finalize(); }
+
+  void Done(const string& device) override;
+  void RecordExecutorStarted() override;
+  void RecordComputeStarted() override;
+  void RecordComputeEnded() override;
+  void RecordExecutorEnded() override;
+  void SetMemory(OpKernelContext* ctx) override;
+  void SetOutput(int slot, const Tensor* tensor) override;
+  void SetReferencedTensors(const TensorReferenceVector& tensors) override;
+  void SetScheduled(int64 nanos) override;
 
  private:
   friend class StepStatsCollector;
@@ -128,9 +125,11 @@
   gtl::InlinedVector<std::pair<AllocatorMemoryUsed*, TrackingAllocator*>, 2>
       allocations_;
   std::unique_ptr<NodeExecStats> stats_;
+  const Node* const node_;                          // Not owned.
+  StepStatsCollector* const step_stats_collector_;  // Not owned.
 };
 
-// Statistics collection interface for individual node execution.
+// Statistics collection interface for step execution.
 //
 // See `StepStatsCollector` for a concrete implementation of this interface
 // that interfaces with the `Session` layer.
@@ -138,8 +137,9 @@
  public:
   virtual ~StepStatsCollectorInterface() {}
 
-  // Saves `stats` to the collector.
-  virtual void Save(const string& device, NodeExecStatsWrapper* stats) = 0;
+  // Creates an instance of `NodeExecStatsInterface` that should be used for
+  // collecting statistics about individual node execution.
+  virtual NodeExecStatsInterface* CreateNodeExecStats(const Node* node) = 0;
 
   // Generates a string reporting the currently used memory based
   // on ResourceExhausted OOM `err` message.
@@ -154,8 +154,8 @@
 // Each DeviceStats object holds multiple NodeExecStats.
 class StepStatsCollector : public StepStatsCollectorInterface {
  public:
-  // Does not take ownership of `ss`.
-  explicit StepStatsCollector(StepStats* ss);
+  // Does not take ownership of `step_stats`.
+  explicit StepStatsCollector(StepStats* step_stats);
 
   // BuildCostModel builds or updates a CostModel managed by cost_model_manager,
   // using the currently collected DeviceStats associated with the devices in
@@ -164,11 +164,12 @@
       CostModelManager* cost_model_manager,
       const std::unordered_map<string, const Graph*>& device_map);
 
-  // Save saves nt to the DeviceStats object associated with device.
+  // Saves node statistics to the DeviceStats object associated with device.
   // Should be called before Finalize.
-  void Save(const string& device, NodeExecStats* nt);
-  void Save(const string& device, NodeExecStatsWrapper* stats) override;
+  void Save(const string& device, NodeExecStats* node_stats_pb);
+  void Save(const string& device, NodeExecStatsWrapper* node_stats);
 
+  NodeExecStatsInterface* CreateNodeExecStats(const Node* node) override;
   string ReportAllocsOnResourceExhausted(const string& err) override;
 
   // The following 2 Finalize methods populate the StepStats passed
@@ -176,20 +177,22 @@
   // User shouldn't call Save() methods after Finalize.
   void Finalize();
   // swaps the content of StepStats* from constructor with 'ss'.
-  void FinalizeAndSwap(StepStats* ss);
+  void FinalizeAndSwap(StepStats* step_stats);
 
  private:
+  // TODO(suharshs): Make this configurable if it's not possible to find a
+  // value that works for all cases.
+  static const uint64 kMaxCollectedNodes = 1 << 20;
+
+  typedef std::vector<std::unique_ptr<NodeExecStatsWrapper>> NodeStatsVector;
+
   void FinalizeInternal() EXCLUSIVE_LOCKS_REQUIRED(mu_);
 
-  typedef std::vector<std::unique_ptr<NodeExecStatsWrapper>> NodeExecStatsVec;
-  // TODO(suharshs): Make this configurable if its not possible to find a value
-  //                 that works for all cases.
-  const uint64 kMaxCollectedNodes = 1 << 20;
   mutex mu_;
   bool finalized_ GUARDED_BY(mu_);
-  std::unordered_map<string, NodeExecStatsVec> dev_stats_ GUARDED_BY(mu_);
+  std::unordered_map<string, NodeStatsVector> dev_stats_ GUARDED_BY(mu_);
   StepStats* step_stats_ GUARDED_BY(mu_);
-  uint64 collectedNodes GUARDED_BY(mu_) = 0;
+  uint64 collected_nodes_ GUARDED_BY(mu_) = 0;
 };
 
 }  // namespace tensorflow
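With this split, callers no longer construct `NodeExecStatsWrapper` directly; they ask the collector for a `NodeExecStatsInterface` and hand it back via `Done()`. A minimal sketch of the intended call sequence, assuming a `StepStatsCollectorInterface* collector`, a `Node* node`, a `device_name` string, and a placeholder `RunKernel()` (none of which are part of this change):

    NodeExecStatsInterface* stats = collector->CreateNodeExecStats(node);
    if (stats != nullptr) {  // nullptr is returned for Send/Recv nodes.
      stats->RecordExecutorStarted();
      stats->RecordComputeStarted();
      RunKernel(node);  // placeholder for the actual kernel invocation
      stats->RecordComputeEnded();
      stats->RecordExecutorEnded();
      // Done() finishes collection for this node; the object must not be
      // used afterwards.
      stats->Done(device_name);
    }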
diff --git a/tensorflow/core/common_runtime/visitable_allocator.h b/tensorflow/core/common_runtime/visitable_allocator.h
deleted file mode 100644
index ae0563a..0000000
--- a/tensorflow/core/common_runtime/visitable_allocator.h
+++ /dev/null
@@ -1,79 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_
-#define TENSORFLOW_CORE_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_
-
-#include <functional>
-#include "tensorflow/core/framework/allocator.h"
-#include "tensorflow/core/framework/tracking_allocator.h"
-
-namespace tensorflow {
-
-// Subclass VisitableAllocator instead of Allocator when a memory
-// allocator needs to enable some kind of registration/deregistration
-// of memory areas.
-class VisitableAllocator : public Allocator {
- public:
-  // Visitor gets called with a pointer to a memory area and its
-  // size in bytes.
-  typedef std::function<void(void*, size_t)> Visitor;
-
-  // Register a visitor guaranteed to be called exactly once on each
-  // chunk of memory newly allocated from the underlying device.
-  // Typically, chunks will be reused and possibly sub-divided by a
-  // pool manager, so the calls will happen only once per process
-  // execution, not once per tensor (re)allocation.
-  virtual void AddAllocVisitor(Visitor visitor) = 0;
-
-  // Register a visitor guaranteed to be called on each chunk of
-  // memory returned to the underlying device.
-  virtual void AddFreeVisitor(Visitor visitor) = 0;
-};
-
-// Needed for cases when a VisitableAllocator gets wrapped for tracking.
-// Multiple-inheritance is considered acceptable in this case because
-// VisitableAllocator is a pure virtual interface and only TrackingAllocator
-// has default implementation.
-class TrackingVisitableAllocator : public TrackingAllocator,
-                                   public VisitableAllocator {
- public:
-  TrackingVisitableAllocator(VisitableAllocator* allocator, bool track_ids)
-      : TrackingAllocator(allocator, track_ids), allocator_(allocator) {}
-  ~TrackingVisitableAllocator() override {}
-
-  string Name() override { return TrackingAllocator::Name(); }
-
-  void* AllocateRaw(size_t alignment, size_t num_bytes) override {
-    return TrackingAllocator::AllocateRaw(alignment, num_bytes);
-  }
-
-  void DeallocateRaw(void* ptr) override {
-    TrackingAllocator::DeallocateRaw(ptr);
-  }
-
-  void AddAllocVisitor(Visitor visitor) override {
-    allocator_->AddAllocVisitor(visitor);
-  }
-
-  void AddFreeVisitor(Visitor visitor) override {
-    allocator_->AddFreeVisitor(visitor);
-  }
-
- protected:
-  VisitableAllocator* allocator_;
-};
-}  // namespace tensorflow
-#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc
index 6c14603..f7a2967 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.cc
+++ b/tensorflow/core/distributed_runtime/graph_mgr.cc
@@ -233,14 +233,11 @@
     params.function_library = lib;
     params.create_kernel = [session, lib, opseg](const NodeDef& ndef,
                                                  OpKernel** kernel) {
-      // We do not share the kernel via the OpSegment if the node is
-      // stateless, or a function.
       // NOTE(mrry): We must not share function kernels (implemented
       // using `CallOp`) between subgraphs, because `CallOp::handle_`
       // is tied to a particular subgraph. Even if the function itself
       // is stateful, the `CallOp` that invokes it is not.
-      if (!lib->IsStateful(ndef.op()) ||
-          lib->GetFunctionLibraryDefinition()->Find(ndef.op()) != nullptr) {
+      if (!OpSegment::ShouldOwnKernel(lib, ndef.op())) {
         return lib->CreateKernel(ndef, kernel);
       }
       auto create_fn = [lib, &ndef](OpKernel** kernel) {
@@ -252,8 +249,7 @@
       return opseg->FindOrCreate(session, ndef.name(), kernel, create_fn);
     };
     params.delete_kernel = [lib](OpKernel* kernel) {
-      // If the node is stateful, opseg owns it. Otherwise, delete it.
-      if (kernel && !lib->IsStateful(kernel->type_string())) {
+      if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string())) {
         delete kernel;
       }
     };
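`OpSegment::ShouldOwnKernel` presumably factors out the ownership test that the removed lines spelled out inline; a sketch of the equivalent condition, inferred from the deleted code rather than quoted from the new helper:

    // Let the OpSegment own (and share) a kernel only when the op is stateful
    // and is not a function; function kernels must stay per-subgraph CallOps.
    bool should_own = lib->IsStateful(op) &&
                      lib->GetFunctionLibraryDefinition()->Find(op) == nullptr;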
diff --git a/tensorflow/core/example/feature_util.h b/tensorflow/core/example/feature_util.h
index ec93b9a..016d1a9 100644
--- a/tensorflow/core/example/feature_util.h
+++ b/tensorflow/core/example/feature_util.h
@@ -103,6 +103,7 @@
 #include <iterator>
 #include <type_traits>
 
+#include "absl/base/macros.h"
 #include "tensorflow/core/example/example.pb.h"
 #include "tensorflow/core/example/feature.pb.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
@@ -113,10 +114,10 @@
 
 namespace internal {
 
-// DEPRECATED: Use GetFeature instead.
 // TODO(gorban): Update all clients in a followup CL.
 // Returns a reference to a feature corresponding to the name.
 // Note: it will create a new Feature if it is missing in the example.
+ABSL_DEPRECATED("Use GetFeature instead.")
 Feature& ExampleFeature(const string& name, Example* example);
 
 // Specializations of RepeatedFieldTrait define a type of RepeatedField
@@ -314,9 +315,9 @@
   return HasFeature<FeatureType...>(key, GetFeatures(example));
 }
 
-// DEPRECATED: use HasFeature instead.
 // TODO(gorban): update all clients in a followup CL.
 template <typename... FeatureType>
+ABSL_DEPRECATED("Use HasFeature instead.")
 bool ExampleHasFeature(const string& key, const Example& example) {
   return HasFeature<FeatureType...>(key, example);
 }
diff --git a/tensorflow/core/framework/allocator.cc b/tensorflow/core/framework/allocator.cc
index 2a7ee16..84cee55 100644
--- a/tensorflow/core/framework/allocator.cc
+++ b/tensorflow/core/framework/allocator.cc
@@ -196,7 +196,7 @@
   class CPUSubAllocator : public SubAllocator {
    public:
     explicit CPUSubAllocator(CPUAllocator* cpu_allocator)
-        : cpu_allocator_(cpu_allocator) {}
+        : SubAllocator({}, {}), cpu_allocator_(cpu_allocator) {}
 
     void* Alloc(size_t alignment, size_t num_bytes) override {
       return cpu_allocator_->AllocateRaw(alignment, num_bytes);
@@ -222,4 +222,22 @@
   }
   return cpu_alloc;
 }
+
+SubAllocator::SubAllocator(const std::vector<Visitor>& alloc_visitors,
+                           const std::vector<Visitor>& free_visitors)
+    : alloc_visitors_(alloc_visitors), free_visitors_(free_visitors) {}
+
+void SubAllocator::VisitAlloc(void* ptr, int index, size_t num_bytes) {
+  for (const auto& v : alloc_visitors_) {
+    v(ptr, index, num_bytes);
+  }
+}
+
+void SubAllocator::VisitFree(void* ptr, int index, size_t num_bytes) {
+  // Although we don't guarantee any order of visitor application, we strive
+  // to apply free visitors in the reverse order of alloc visitors.
+  for (int i = free_visitors_.size() - 1; i >= 0; --i) {
+    free_visitors_[i](ptr, index, num_bytes);
+  }
+}
 }  // namespace tensorflow
diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h
index ded120b..8c23604 100644
--- a/tensorflow/core/framework/allocator.h
+++ b/tensorflow/core/framework/allocator.h
@@ -24,6 +24,7 @@
 #include "tensorflow/core/framework/resource_handle.h"
 #include "tensorflow/core/framework/type_traits.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
@@ -387,13 +388,36 @@
 // full statistics. By default, it's disabled.
 void EnableCPUAllocatorFullStats(bool enable);
 
-// Abstract interface of an object that does the underlying suballoc/free of
-// memory for a higher-level allocator.
+// An object that does the underlying suballoc/free of memory for a higher-level
+// allocator.  The expectation is that the higher-level allocator is doing some
+// kind of cache or pool management so that it will call SubAllocator::Alloc and
+// Free relatively infrequently, compared to the number of times its own
+// AllocateRaw and Free methods are called.
 class SubAllocator {
  public:
+  // Visitor gets called with a pointer to a memory area and its
+  // size in bytes.  The index value will be numa_node for a CPU
+  // allocator and GPU id for a GPU allocator.
+  typedef std::function<void(void*, int index, size_t)> Visitor;
+
+  SubAllocator(const std::vector<Visitor>& alloc_visitors,
+               const std::vector<Visitor>& free_visitors);
+
   virtual ~SubAllocator() {}
   virtual void* Alloc(size_t alignment, size_t num_bytes) = 0;
   virtual void Free(void* ptr, size_t num_bytes) = 0;
+
+ protected:
+  // Implementation of Alloc() method must call this on newly allocated
+  // value.
+  void VisitAlloc(void* ptr, int index, size_t num_bytes);
+
+  // Implementation of Free() method must call this on value to be
+  // freed immediately before deallocation.
+  void VisitFree(void* ptr, int index, size_t num_bytes);
+
+  const std::vector<Visitor> alloc_visitors_;
+  const std::vector<Visitor> free_visitors_;
 };
 
 }  // namespace tensorflow
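Because `SubAllocator` now owns the visitor lists, every concrete suballocator must pass them to the base constructor and invoke them around its own allocation calls. A minimal sketch under those assumptions; `BasicSubAllocator` is illustrative only, and `port::AlignedMalloc`/`port::AlignedFree` merely stand in for a real backing store:

    class BasicSubAllocator : public SubAllocator {
     public:
      BasicSubAllocator(const std::vector<Visitor>& alloc_visitors,
                        const std::vector<Visitor>& free_visitors)
          : SubAllocator(alloc_visitors, free_visitors) {}

      void* Alloc(size_t alignment, size_t num_bytes) override {
        void* ptr = port::AlignedMalloc(num_bytes, static_cast<int>(alignment));
        VisitAlloc(ptr, /*index=*/0, num_bytes);  // required on new memory
        return ptr;
      }

      void Free(void* ptr, size_t num_bytes) override {
        VisitFree(ptr, /*index=*/0, num_bytes);  // must precede deallocation
        port::AlignedFree(ptr);
      }
    };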
diff --git a/tensorflow/core/framework/cancellation.cc b/tensorflow/core/framework/cancellation.cc
index 1258e40..af59500 100644
--- a/tensorflow/core/framework/cancellation.cc
+++ b/tensorflow/core/framework/cancellation.cc
@@ -89,6 +89,16 @@
   }
 }
 
+bool CancellationManager::TryDeregisterCallback(CancellationToken token) {
+  mutex_lock lock(mu_);
+  if (is_cancelled_ || is_cancelling_) {
+    return false;
+  } else {
+    callbacks_.erase(token);
+    return true;
+  }
+}
+
 CancellationManager::~CancellationManager() {
   if (!callbacks_.empty()) {
     StartCancel();
diff --git a/tensorflow/core/framework/cancellation.h b/tensorflow/core/framework/cancellation.h
index acdaaf6..7a5d942 100644
--- a/tensorflow/core/framework/cancellation.h
+++ b/tensorflow/core/framework/cancellation.h
@@ -122,6 +122,15 @@
   // cancellation manager.
   bool DeregisterCallback(CancellationToken token);
 
+  // Deregister the callback that, when registered, was associated
+  // with the given cancellation token. Returns true iff the callback
+  // was deregistered and will not be invoked; otherwise returns false
+  // immediately, with no guarantee that the callback has completed.
+  //
+  // This method is guaranteed to return true if StartCancel has not been
+  // called.
+  bool TryDeregisterCallback(CancellationToken token);
+
  private:
   bool is_cancelling_;
   std::atomic_bool is_cancelled_;
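The tests below exercise the new method; the intended usage pattern in an asynchronous op looks roughly like the following sketch, where `cm`, `op->DoWork()`, and `op->Cancel()` are placeholders rather than TensorFlow APIs:

    CancellationToken token = cm->get_cancellation_token();
    bool registered = cm->RegisterCallback(token, [op]() { op->Cancel(); });
    if (registered) {
      op->DoWork();
      if (!cm->TryDeregisterCallback(token)) {
        // Cancellation already started: the callback may still run, so any
        // state it touches must stay alive until cancellation completes.
      }
    }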
diff --git a/tensorflow/core/framework/cancellation_test.cc b/tensorflow/core/framework/cancellation_test.cc
index e3f1824..bf7593b 100644
--- a/tensorflow/core/framework/cancellation_test.cc
+++ b/tensorflow/core/framework/cancellation_test.cc
@@ -115,4 +115,56 @@
   delete cm;
 }
 
+TEST(Cancellation, TryDeregisterWithoutCancel) {
+  bool is_cancelled = false;
+  CancellationManager* manager = new CancellationManager();
+  auto token = manager->get_cancellation_token();
+  bool registered = manager->RegisterCallback(
+      token, [&is_cancelled]() { is_cancelled = true; });
+  EXPECT_TRUE(registered);
+  bool deregistered = manager->TryDeregisterCallback(token);
+  EXPECT_TRUE(deregistered);
+  delete manager;
+  EXPECT_FALSE(is_cancelled);
+}
+
+TEST(Cancellation, TryDeregisterAfterCancel) {
+  bool is_cancelled = false;
+  CancellationManager* manager = new CancellationManager();
+  auto token = manager->get_cancellation_token();
+  bool registered = manager->RegisterCallback(
+      token, [&is_cancelled]() { is_cancelled = true; });
+  EXPECT_TRUE(registered);
+  manager->StartCancel();
+  EXPECT_TRUE(is_cancelled);
+  bool deregistered = manager->TryDeregisterCallback(token);
+  EXPECT_FALSE(deregistered);
+  delete manager;
+}
+
+TEST(Cancellation, TryDeregisterDuringCancel) {
+  Notification cancel_started, finish_callback, cancel_complete;
+  CancellationManager* manager = new CancellationManager();
+  auto token = manager->get_cancellation_token();
+  bool registered = manager->RegisterCallback(token, [&]() {
+    cancel_started.Notify();
+    finish_callback.WaitForNotification();
+  });
+  EXPECT_TRUE(registered);
+
+  thread::ThreadPool w(Env::Default(), "test", 1);
+  w.Schedule([&]() {
+    manager->StartCancel();
+    cancel_complete.Notify();
+  });
+  cancel_started.WaitForNotification();
+
+  bool deregistered = manager->TryDeregisterCallback(token);
+  EXPECT_FALSE(deregistered);
+
+  finish_callback.Notify();
+  cancel_complete.WaitForNotification();
+  delete manager;
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc
index 5281c56..284dafb 100644
--- a/tensorflow/core/framework/dataset.cc
+++ b/tensorflow/core/framework/dataset.cc
@@ -20,7 +20,6 @@
 
 namespace tensorflow {
 namespace data {
-
 namespace {
 
 // A wrapper class for storing a `DatasetBase` instance in a DT_VARIANT tensor.
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index 4ee6749..697e060 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -47,6 +47,8 @@
 class Node;
 
 namespace data {
+// A constant that can be used to enable auto-tuning.
+constexpr int kAutoTune = -1;
 
 class DatasetBase;
 class SerializationContext;
@@ -527,25 +529,11 @@
                       std::unique_ptr<IteratorBase>* iterator) const {
     *iterator = MakeIteratorInternal(prefix);
     if (ctx->model()) {
-      // The prefix might contain an index. We need to strip it to make it
-      // possible for the model to successfully identify the output node.
-      string sanitized_prefix = prefix;
-      if (str_util::EndsWith(prefix, "]")) {
-        sanitized_prefix = prefix.substr(0, prefix.rfind('['));
-      }
-      std::shared_ptr<model::Node> node =
-          ctx->model()->AddNode((*iterator)->prefix(), sanitized_prefix);
-      std::vector<string> tokens =
-          str_util::Split((*iterator)->prefix(), ':', str_util::SkipEmpty());
-      node->set_name(tokens[tokens.size() - 1]);
+      ctx->model()->AddNode((*iterator)->prefix(), prefix);
       std::shared_ptr<model::Model> model = ctx->model();
       const string& prefix = (*iterator)->prefix();
-      (*iterator)->AddCleanupFunction([model, node, prefix]() {
-        if (node->output()) {
-          node->output()->remove_input(node);
-        }
-        model->RemoveNode(prefix);
-      });
+      (*iterator)->AddCleanupFunction(
+          [model, prefix]() { model->RemoveNode(prefix); });
     }
     return (*iterator)->Initialize(ctx);
   }
@@ -627,23 +615,10 @@
   Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                  bool* end_of_sequence) final {
     tracing::ScopedActivity activity(params_.prefix);
-    Status s;
-    if (ctx->model()) {
-      std::shared_ptr<model::Node> node =
-          ctx->model()->LookupNode(params_.prefix);
-      if (node->output()) {
-        node->output()->stop_work();
-      }
-      node->start_work();
-      s = GetNextInternal(ctx, out_tensors, end_of_sequence);
-      node->stop_work();
-      node->add_element();
-      if (node->output()) {
-        node->output()->start_work();
-      }
-    } else {
-      s = GetNextInternal(ctx, out_tensors, end_of_sequence);
-    }
+    RecordStart(ctx, true /* stop_output */);
+    Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
+    if (s.ok() && !*end_of_sequence) RecordElement(ctx);
+    RecordStop(ctx, true /* start_output */);
     if (TF_PREDICT_FALSE(errors::IsOutOfRange(s) && !*end_of_sequence)) {
       s = errors::Internal(
           "Iterator \"", params_.prefix,
@@ -670,36 +645,51 @@
     return strings::StrCat(params_.prefix, ":", name);
   }
 
-  // When performance modeling is enabled, this method sets metadata entry for
-  // the model node corresponding to this iterator.
-  void SetMetadata(IteratorContext* ctx, const string& key, int64 value) {
+  // When performance modeling is enabled, this method adds a constant parameter
+  // to the model node corresponding to this iterator.
+  void AddConstantParameter(IteratorContext* ctx, const string& name,
+                            int64 value) {
     if (ctx->model()) {
-      std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix());
-      if (node) {
-        node->set_metadata(key, value);
-      }
+      ctx->model()->AddConstantParameter(prefix(), name, value);
+    }
+  }
+
+  // When performance modeling is enabled, this method adds a tunable parameter
+  // to the model node corresponding to this iterator.
+  //
+  // The performance modeling logic may use `value` to set the value of the
+  // tunable parameter at any point during the lifetime of this iterator. When
+  // it does, it notifies `cond_var`.
+  void AddTunableParameter(IteratorContext* ctx, const string& name,
+                           std::atomic<int64>* value, int64 min, int64 max,
+                           condition_variable* cond_var) {
+    if (ctx->model()) {
+      ctx->model()->AddTunableParameter(prefix(), name, value, min, max,
+                                        cond_var);
+    }
+  }
+
+  // When performance modeling is enabled, this method records the fact that
+  // this iterator has produced an element.
+  void RecordElement(IteratorContext* ctx) {
+    if (ctx->model()) {
+      ctx->model()->RecordElement(prefix());
     }
   }
 
   // When performance modeling is enabled, this method records the fact that
   // a thread of this iterator has started work.
-  void StartWork(IteratorContext* ctx) {
+  void RecordStart(IteratorContext* ctx, bool stop_output = false) {
     if (ctx->model()) {
-      std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix());
-      if (node) {
-        node->start_work();
-      }
+      ctx->model()->RecordStart(prefix(), stop_output);
     }
   }
 
   // When performance modeling is enabled, this method records the fact that
   // a thread of this iterator has stopped work.
-  void StopWork(IteratorContext* ctx) {
+  void RecordStop(IteratorContext* ctx, bool start_output = false) {
     if (ctx->model()) {
-      std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix());
-      if (node) {
-        node->stop_work();
-      }
+      ctx->model()->RecordStop(prefix(), start_output);
     }
   }
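Taken together, these wrappers replace the direct `LookupNode`/`start_work`/`stop_work` calls. A minimal sketch of how a multi-threaded transformation is expected to use them; `batch_size_`, `mu_`, `cond_var_`, and `buffer_` are assumed members for illustration and are not part of this change:

    Status Initialize(IteratorContext* ctx) override {
      // Constant knobs such as the batch size feed the processing-time model.
      AddConstantParameter(ctx, "batch_size", dataset()->batch_size_);
      return Status::OK();
    }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      mutex_lock l(mu_);
      while (buffer_.empty()) {
        // Yield the CPU around the wait so the model does not charge the
        // blocked time to this iterator.
        RecordStop(ctx);
        cond_var_.wait(l);
        RecordStart(ctx);
      }
      *out_tensors = std::move(buffer_.front());
      buffer_.pop_front();
      *end_of_sequence = false;
      return Status::OK();
    }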
 
diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h
index 794250a..446c31b 100644
--- a/tensorflow/core/framework/device_base.h
+++ b/tensorflow/core/framework/device_base.h
@@ -20,6 +20,7 @@
 #include <string>
 #include <vector>
 
+#include "absl/base/macros.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/refcount.h"
@@ -176,9 +177,9 @@
     return nullptr;
   }
 
-  // DEPRECATED: Use `this->GetAllocator()` or `this->GetScopedAllocator()`.
   // This method is provided for backwards compatibility, and will be removed
   // in a future release.
+  ABSL_DEPRECATED("Use `this->GetAllocator()` or `this->GetScopedAllocator()`.")
   Allocator* GetStepAllocator(AllocatorAttributes attr, ResourceMgr*) {
     return GetAllocator(attr);
   }
@@ -214,10 +215,12 @@
 
   // This is overridden by GPU devices to reinitialize the derived
   // type returned by MakeGpuDevice.
-  virtual void ReinitializeGpuDevice(OpKernelContext* /*context*/,
-                                     PerOpGpuDevice* /*device*/,
-                                     DeviceContext* /*dc*/,
-                                     Allocator* /*allocator*/) {}
+  virtual Status ReinitializeGpuDevice(OpKernelContext* /*context*/,
+                                       PerOpGpuDevice* /*device*/,
+                                       DeviceContext* /*dc*/,
+                                       Allocator* /*allocator*/) {
+    return Status::OK();
+  }
 
   // Unimplemented by default
   virtual const DeviceAttributes& attributes() const;
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index d979353d..a17959a 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -1294,6 +1294,18 @@
   for (const auto& r : ret_def) {
     fdef.mutable_ret()->insert({r.first, r.second});
   }
+
+  auto* op_def_registry = OpRegistry::Global();
+  // Check if any op is stateful.
+  for (const auto& n : node_def) {
+    const OpDef* op_def = nullptr;
+    auto status = op_def_registry->LookUpOpDef(n.op, &op_def);
+    // Lookup can fail if e.g. we are calling a function that was not yet
+    // defined. If that happens, conservatively assume the op is stateful.
+    if (!status.ok() || op_def->is_stateful()) {
+      fdef.mutable_signature()->set_is_stateful(true);
+    }
+  }
   return fdef;
 }
 
@@ -1355,6 +1367,7 @@
             strings::StrCat(src.ret[0], ":", o.first, ":", i - o.second.first);
       }
     }
+    if (op_def->is_stateful()) fdef.mutable_signature()->set_is_stateful(true);
   }
 
   // Returns
diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc
index c5a4f66..d5c203d 100644
--- a/tensorflow/core/framework/function_testlib.cc
+++ b/tensorflow/core/framework/function_testlib.cc
@@ -91,6 +91,40 @@
       });
 }
 
+FunctionDef RandomUniform() {
+  const Tensor kZero = test::AsScalar<int64>(0);
+  const Tensor kTen = test::AsScalar<int64>(10);
+
+  return FDH::Define(
+      // Name
+      "RandomUniform",
+      // Args
+      {"x: T"},
+      // Return values
+      {"random_uniform: int64"},
+      // Attr def
+      {"T:{float, double, int32, int64, string}"},
+      {{{"random_uniform/shape"},
+        "Const",
+        {},
+        {{"value", kZero}, {"dtype", DT_INT64}}},
+       {{"random_uniform/min"},
+        "Const",
+        {},
+        {{"value", kZero}, {"dtype", DT_INT64}}},
+       {{"random_uniform/max"},
+        "Const",
+        {},
+        {{"value", kTen}, {"dtype", DT_INT64}}},
+       {{"random_uniform"},
+        "RandomUniformInt",
+        {},
+        {{"T", DT_INT64},
+         {"Tout", DT_INT64},
+         {"seed", 87654321},
+         {"seed2", 42}}}});
+}
+
 FunctionDef XTimesTwo() {
   const Tensor kTwo = test::AsScalar<int64>(2);
   return FDH::Define(
diff --git a/tensorflow/core/framework/function_testlib.h b/tensorflow/core/framework/function_testlib.h
index ad61a76..a017434 100644
--- a/tensorflow/core/framework/function_testlib.h
+++ b/tensorflow/core/framework/function_testlib.h
@@ -84,6 +84,9 @@
 // x: T -> bool.
 FunctionDef IsZero();
 
+// x: T -> int64
+FunctionDef RandomUniform();
+
 // x:T, y:T -> y:T, x:T
 FunctionDef Swap();
 
diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc
index 250b006..b0330ec 100644
--- a/tensorflow/core/framework/model.cc
+++ b/tensorflow/core/framework/model.cc
@@ -1,4 +1,4 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -15,52 +15,26 @@
 
 #include "tensorflow/core/framework/model.h"
 
+#include <memory>
+
 namespace tensorflow {
 namespace data {
 namespace model {
 
 // TODO(jsimsa): Use `Node` subclassing instead of types and node statements.
-void Node::CollectKnobs(std::vector<Node::Knob>* knobs) {
-  mutex_lock l(mu_);
+void Model::Node::CollectTunables(
+    std::vector<std::shared_ptr<Node::Tunable>>* tunables) {
+  tf_shared_lock l(mu_);
+  for (auto input : inputs_) {
+    input->CollectTunables(tunables);
+  }
   switch (type_) {
-    case Type::PARALLEL_INTERLEAVE_V2: {
-      for (auto input : inputs_) {
-        input->CollectKnobs(knobs);
-      }
-      int64 processing_time = static_cast<int64>(
-          static_cast<double>(ProcessingTimeLocked() -
-                              inputs_.front()->ProcessingTime()) /
-          static_cast<double>(inputs_.size() - 1));
-      knobs->emplace_back(
-          Node::Knob{this, processing_time, metadata_["parallelism"]});
-      return;
-    }
     case Type::MAP_AND_BATCH:
+    case Type::PARALLEL_INTERLEAVE_V2:
     case Type::PARALLEL_MAP: {
-      for (auto input : inputs_) {
-        input->CollectKnobs(knobs);
-      }
-      knobs->emplace_back(
-          Node::Knob{this, NanosPerElementLocked(), metadata_["parallelism"]});
-      return;
-    }
-    case Type::BATCH:
-    case Type::CACHE:
-    case Type::CONCATENATE:
-    case Type::FILTER:
-    case Type::FLAT_MAP:
-    case Type::INTERLEAVE:
-    case Type::MAP:
-    case Type::PADDED_BATCH:
-    case Type::PARALLEL_INTERLEAVE:
-    case Type::PREFETCH:
-    case Type::REPEAT:
-    case Type::SHUFFLE:
-    case Type::SKIP:
-    case Type::TAKE:
-    case Type::ZIP: {
-      for (auto input : inputs_) {
-        input->CollectKnobs(knobs);
+      if (auto* tunable_param =
+              gtl::FindOrNull(tunable_params_, "parallelism")) {
+        tunables->push_back(*tunable_param);
       }
       return;
     }
@@ -69,12 +43,19 @@
   }
 }
 
-int64 Node::ProcessingTimeLocked() {
+int64 Model::Node::GetParameterValue(const string& name) {
+  if (auto* tunable_param = gtl::FindOrNull(tunable_params_, name)) {
+    return (*tunable_param)->value;
+  }
+  return constant_params_[name];
+}
+
+int64 Model::Node::ProcessingTimeLocked() {
   switch (type_) {
     case Type::BATCH:
     case Type::MAP_AND_BATCH:
     case Type::PADDED_BATCH: {
-      int64 batch_size = metadata_["batch_size"];
+      int64 batch_size = GetParameterValue("batch_size");
       return NanosPerElementLocked() + batch_size * ProcessingTimeForInputs();
     }
     case Type::FILTER: {
@@ -118,11 +99,11 @@
   }
 }
 
-int64 Node::OutputTimeLocked(std::vector<int64>* input_times) {
+int64 Model::Node::OutputTimeLocked(std::vector<int64>* input_times) {
   switch (type_) {
     case Type::BATCH:
     case Type::PADDED_BATCH: {
-      double batch_size = metadata_["batch_size"];
+      double batch_size = GetParameterValue("batch_size");
       int64 old_value = (*input_times)[input_times->size() - 1];
       (*input_times)[input_times->size() - 1] = static_cast<int64>(
           static_cast<double>(old_value + NanosPerElementLocked()) /
@@ -168,8 +149,8 @@
                  static_cast<double>(inputs_.size() - 1);
     }
     case Type::MAP_AND_BATCH: {
-      double batch_size = metadata_["batch_size"];
-      double parallelism = metadata_["parallelism"];
+      double batch_size = GetParameterValue("batch_size");
+      double parallelism = GetParameterValue("parallelism");
       int64 delta =
           static_cast<int64>(static_cast<double>(NanosPerElementLocked()) /
                              (batch_size * parallelism));
@@ -182,22 +163,41 @@
       return std::max(0LL,
                       output_time - input_times->at(input_times->size() - 2));
     }
-    case Type::PARALLEL_INTERLEAVE:
-    case Type::PARALLEL_INTERLEAVE_V2: {
+    case Type::PARALLEL_INTERLEAVE: {
       // TODO(jsimsa): model the first input
       if (inputs_.size() <= 1) {
         return NanosPerElementLocked();
       }
-      int64 delta =
-          static_cast<int64>(static_cast<double>(NanosPerElementLocked()) *
-                             static_cast<double>(inputs_.size() - 1));
+      int64 delta = static_cast<double>(NanosPerElementLocked()) *
+                    static_cast<double>(inputs_.size() - 1);
       input_times->push_back(delta);
       auto cleanup =
           gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
       int64 inputs_output_time = OutputTimeForInputs(input_times) -
                                  inputs_.front()->OutputTime(input_times);
-      double parallelism = std::min(port::NumSchedulableCPUs(),
-                                    static_cast<int>(metadata_["parallelism"]));
+      double parallelism = GetParameterValue("parallelism");
+      int64 output_time =
+          NanosPerElementLocked() + ((static_cast<double>(inputs_output_time) /
+                                      static_cast<double>(inputs_.size() - 1)) /
+                                     parallelism);
+      return std::max(0LL,
+                      output_time - input_times->at(input_times->size() - 2));
+    }
+    case Type::PARALLEL_INTERLEAVE_V2: {
+      // TODO(jsimsa): model the first input
+      if (inputs_.size() <= 1) {
+        return NanosPerElementLocked();
+      }
+      int64 delta = static_cast<double>(NanosPerElementLocked()) *
+                    static_cast<double>(inputs_.size() - 1);
+      input_times->push_back(delta);
+      auto cleanup =
+          gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
+      int64 inputs_output_time = OutputTimeForInputs(input_times) -
+                                 inputs_.front()->OutputTime(input_times);
+      double parallelism =
+          std::min(static_cast<int>(GetParameterValue("cycle_length")),
+                   static_cast<int>(GetParameterValue("parallelism")));
       int64 output_time =
           NanosPerElementLocked() + ((static_cast<double>(inputs_output_time) /
                                       static_cast<double>(inputs_.size() - 1)) /
@@ -206,8 +206,9 @@
                       output_time - input_times->at(input_times->size() - 2));
     }
     case Type::PARALLEL_MAP: {
-      double parallelism = std::min(port::NumSchedulableCPUs(),
-                                    static_cast<int>(metadata_["parallelism"]));
+      double parallelism =
+          std::min(port::NumSchedulableCPUs(),
+                   static_cast<int>(GetParameterValue("parallelism")));
       int64 delta = static_cast<int64>(
           static_cast<double>(NanosPerElementLocked()) / parallelism);
       input_times->push_back(delta);
@@ -248,32 +249,34 @@
   }
 }
 
-Model::Model(const proto::Model& model_proto) {
-  id_counter_ = model_proto.id_counter();
-  std::map<int64, std::shared_ptr<Node>> lookup_table;
-  for (auto node_proto : model_proto.node()) {
-    std::shared_ptr<Node> node(new Node(node_proto));
-    lookup_table[node_proto.id()] = node;
+void Model::AddConstantParameter(const string& node_name,
+                                 const string& parameter_name, int64 value) {
+  tf_shared_lock l(mu_);
+  auto node = gtl::FindOrNull(lookup_table_, node_name);
+  if (node) {
+    (*node)->add_constant_param(parameter_name, value);
   }
-  for (auto node_proto : model_proto.node()) {
-    std::shared_ptr<Node> node = lookup_table[node_proto.id()];
-    for (int64 id : node_proto.input()) {
-      node->add_input(lookup_table[id]);
-    }
-    node->set_output(lookup_table[node_proto.output()]);
-  }
-  output_ = lookup_table[model_proto.output()];
 }
 
-std::shared_ptr<Node> Model::AddNode(const string& name,
-                                     const string& output_name) {
-  mutex_lock l(mu_);
+void Model::AddNode(const string& name, const string& output_name) {
+  // The name captures the sequence of iterators joined by `::`. We use the full
+  // sequence as the key in the lookup table, but only the last element of the
+  // sequence as the node name.
+  std::vector<string> tokens =
+      str_util::Split(name, ':', str_util::SkipEmpty());
+  // The output name might contain an index. We need to strip it to make it
+  // possible for the model to successfully identify the output node.
+  string sanitized_output_name = output_name;
+  if (str_util::EndsWith(output_name, "]")) {
+    sanitized_output_name = output_name.substr(0, output_name.rfind('['));
+  }
   std::shared_ptr<Node> output;
-  auto it = lookup_table_.find(output_name);
+  mutex_lock l(mu_);
+  auto it = lookup_table_.find(sanitized_output_name);
   if (it != lookup_table_.end()) {
     output = it->second;
   }
-  std::shared_ptr<Node> node(new Node(id_counter_++, output));
+  std::shared_ptr<Node> node(new Node(id_counter_++, tokens.back(), output));
   if (!output_) {
     output_ = node;
   }
@@ -281,107 +284,127 @@
     output->add_input(node);
   }
   lookup_table_.insert(std::make_pair(name, node));
-  return node;
 }
 
-std::shared_ptr<Node> Model::LookupNode(const string& name) {
+void Model::AddProcessingTime(const string& name, int64 delta) {
   tf_shared_lock l(mu_);
-  std::shared_ptr<Node> result;
-  auto it = lookup_table_.find(name);
-  if (it != lookup_table_.end()) {
-    result = it->second;
+  auto node = gtl::FindOrNull(lookup_table_, name);
+  if (node) {
+    (*node)->add_processing_time(delta);
   }
-  return result;
 }
 
-void Model::Optimize() {
-  mutex_lock l(mu_);
-  int64 processing_time = ProcessingTime();
-  int64 num_cpus = port::NumSchedulableCPUs();
-  std::vector<Node::Knob> knobs = CollectKnobs();
-  // The optimization algorithm starts by setting all parallelism knobs to 1. It
-  // then repeatedly identifies the knob that, when turned up by 1, decreases
-  // the output time the most. This process is repeated until all knobs reach
-  // the number of schedulable CPUs or the projected output time is less than or
-  // equal to the processing time needed to produce an element divided by the
-  // number of schedulable CPUs.
-  for (auto& knob : knobs) {
-    LOG(INFO) << knob.node->name() << " " << knob.processing_time;
-    knob.value = 1;
-    knob.node->set_metadata("parallelism", knob.value);
+void Model::AddTunableParameter(const string& node_name,
+                                const string& parameter_name,
+                                std::atomic<int64>* value, int64 min, int64 max,
+                                condition_variable* cond_var) {
+  tf_shared_lock l(mu_);
+  auto node = *gtl::FindOrNull(lookup_table_, node_name);
+  DCHECK(node);
+  node->add_tunable_param(parameter_name, value, min, max, cond_var);
+}
+
+// The optimization algorithm starts by setting all tunable parallelism
+// parameters to 1. It then repeatedly identifies the parameter whose increase
+// in parallelism decreases the output time the most. This process is repeated
+// until all parameters reach their maximum values or the projected output time
+// is less than or equal to the processing time needed to produce an element
+// divided by the CPU budget.
+void Model::Optimize(int64 cpu_budget) {
+  tf_shared_lock lock(mu_);
+  std::vector<std::shared_ptr<Model::Node::Tunable>> tunables;
+  const int64 processing_time = ProcessingTime();
+  tunables = CollectTunables();
+  for (auto tunable : tunables) {
+    tunable->value = 1;
   }
   while (true) {
-    int64 output_time = OutputTime();
-    bool all_knobs = true;
-    for (auto knob : knobs) {
-      if (knob.value < num_cpus) {
-        all_knobs = false;
+    const int64 output_time = OutputTime();
+    bool all_tunables = true;
+    for (auto& tunable : tunables) {
+      if (tunable->value < tunable->max) {
+        all_tunables = false;
         break;
       }
     }
-    if (output_time < processing_time / num_cpus || all_knobs) {
+    if (output_time < processing_time / cpu_budget || all_tunables) {
       break;
     }
     int64 best_delta = -1;
-    int best_knob = -1;
-    for (int i = 0; i < knobs.size(); ++i) {
-      if (knobs[i].value == num_cpus) {
+    Model::Node::Tunable* best_tunable = nullptr;
+    for (auto& tunable : tunables) {
+      if (tunable->value == tunable->max) {
         continue;
       }
-      knobs[i].node->set_metadata("parallelism", knobs[i].value + 1);
+      tunable->value++;
       int64 delta = output_time - OutputTime();
       if (delta > best_delta) {
         best_delta = delta;
-        best_knob = i;
+        best_tunable = tunable.get();
       }
-      knobs[i].node->set_metadata("parallelism", knobs[i].value);
+      tunable->value--;
     }
-    knobs[best_knob].value++;
-    knobs[best_knob].node->set_metadata("parallelism", knobs[best_knob].value);
+    if (!best_tunable) {
+      // NOTE: This can happen because we are performing the optimization
+      // while the model data is changing. If this becomes an issue, we should
+      // look into performing the optimization using a model snapshot.
+      break;
+    }
+    best_tunable->value++;
   }
-  for (auto knob : knobs) {
-    LOG(INFO) << knob.node->name() << " " << knob.value;
+  VLOG(2) << "Number of knobs: " << tunables.size();
+  for (auto& tunable : tunables) {
+    VLOG(2) << "Setting tunable parameter: " << tunable->value;
+    tunable->value_ptr->store(tunable->value);
+    if (tunable->cond_var) {
+      tunable->cond_var->notify_all();
+    }
   }
-  LOG(INFO) << "output time: " << OutputTime();
-  LOG(INFO) << "processing time: " << ProcessingTime();
 }
 
-void Model::OutputToFile() {
-  proto::Model model_proto;
-  ToProto(&model_proto);
-  string filename;
-  Env::Default()->LocalTempFilename(&filename);
-  TF_CHECK_OK(WriteStringToFile(Env::Default(), filename,
-                                model_proto.SerializeAsString()));
-  LOG(INFO) << filename;
+void Model::RecordElement(const string& name) {
+  tf_shared_lock l(mu_);
+  auto node = gtl::FindOrNull(lookup_table_, name);
+  if (node) {
+    (*node)->record_element();
+  }
 }
 
-void Model::RemoveNode(const string& prefix) {
+void Model::RecordStart(const string& name, bool stop_output) {
+  tf_shared_lock l(mu_);
+  auto node = gtl::FindOrNull(lookup_table_, name);
+  if (node) {
+    if (stop_output && (*node)->output()) {
+      (*node)->output()->record_stop();
+    }
+    (*node)->record_start();
+  }
+}
+
+void Model::RecordStop(const string& name, bool start_output) {
+  tf_shared_lock l(mu_);
+  auto node = gtl::FindOrNull(lookup_table_, name);
+  if (node) {
+    (*node)->record_stop();
+    if (start_output && (*node)->output()) {
+      (*node)->output()->record_start();
+    }
+  }
+}
+
+void Model::RemoveNode(const string& name) {
   mutex_lock l(mu_);
-  lookup_table_.erase(prefix);
-}
-
-void Model::ToProto(proto::Model* model_proto) {
-  mutex_lock l(mu_);
-  model_proto->set_id_counter(id_counter_);
-  model_proto->set_output(output_->id());
-  AddNodeToProto(output_, model_proto);
-}
-
-// static
-void Model::AddNodeToProto(const std::shared_ptr<Node>& node,
-                           proto::Model* model_proto) {
-  proto::Node* node_proto = model_proto->add_node();
-  node->ToProto(node_proto);
-  for (const std::shared_ptr<Node>& input : node->inputs()) {
-    AddNodeToProto(input, model_proto);
+  auto node = gtl::FindOrNull(lookup_table_, name);
+  if (node && (*node)->output()) {
+    (*node)->output()->remove_input(*node);
   }
+  lookup_table_.erase(name);
 }
 
-std::vector<Node::Knob> Model::CollectKnobs() {
-  std::vector<Node::Knob> knobs;
-  output_->CollectKnobs(&knobs);
-  return knobs;
+std::vector<std::shared_ptr<Model::Node::Tunable>> Model::CollectTunables() {
+  std::vector<std::shared_ptr<Model::Node::Tunable>> tunables;
+  output_->CollectTunables(&tunables);
+  return tunables;
 }
 
 int64 Model::OutputTime() {
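How the two halves of the new API meet is easiest to see in a sketch (assumed usage mirroring the signatures above, not code contained in this change): a parallel transformation publishes its parallelism as a tunable value, and a periodic call to `Model::Optimize` writes the chosen setting back and wakes waiting workers.

    // In the iterator: expose "parallelism" as a tunable backed by an atomic.
    std::atomic<int64> num_parallel_calls(kAutoTune);  // kAutoTune from dataset.h
    condition_variable cond_var;
    AddTunableParameter(ctx, "parallelism", &num_parallel_calls,
                        /*min=*/1, /*max=*/port::NumSchedulableCPUs(), &cond_var);

    // In the pipeline's tuning thread: Optimize() stores new values through
    // `value_ptr` and notifies `cond_var`, so workers blocked on the old
    // parallelism wake up and pick up the new setting.
    model->Optimize(/*cpu_budget=*/port::NumSchedulableCPUs());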
diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h
index 9817290..26402f5 100644
--- a/tensorflow/core/framework/model.h
+++ b/tensorflow/core/framework/model.h
@@ -22,9 +22,9 @@
 #include <utility>
 #include <vector>
 
-#include "tensorflow/core/framework/model.pb.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/random/random.h"
 #include "tensorflow/core/platform/cpu_info.h"
 #include "tensorflow/core/platform/env.h"
@@ -33,356 +33,364 @@
 namespace data {
 namespace model {
 
-class Model;
-class Node;
-
-// Abstract representation of a TensorFlow input pipeline node. It collects
-// information about inputs to this node, processing time spent executing the
-// node logic, number of elements produced by the node, various other
-// information (e.g. batch size or execution parallelism).
-//
-// Developers of tf.data transformations are not expected to interact with this
-// class directly. Boiler plate code for creating the abstract representation of
-// the input pipeline and collecting common information has been added to the
-// implementation of `DatasetBase` and `DatasetBaseIterator` respectively.
-//
-// In addition, `DatasetBaseIterator` provides wrappers that can be used for
-// transformation-specific information collection. The `SetMetadata` wrapper can
-// be used to pass arbitrary metadata to the modeling framework, while the
-// `StartWork` and `StopWork` wrappers should be used to correctly account for
-// processing time of multi-threaded transformation that yield the CPU; such
-// transformations should invoke `StartWork()` when a transformation thread
-// starts executing (e.g. when created or woken up) and `StopWork()` when a
-// transformation thread stops executing (e.g. when returning or waiting).
-//
-// TODO(jsimsa): Create an API to capture the abstract semantics of each
-// tf.data transformation and replace switch-case blocks with inheritance.
-class Node {
- public:
-  Node(int64 id, std::shared_ptr<Node> output) : id_(id), output_(output) {}
-
-  explicit Node(const proto::Node& node_proto) : id_(node_proto.id()) {
-    name_ = node_proto.name();
-    type_ = TypeFromName(node_proto.name());
-    processing_time_ = node_proto.processing_time();
-    num_elements_ = node_proto.num_elements();
-    metadata_.insert(node_proto.metadata().begin(),
-                     node_proto.metadata().end());
-  }
-
-  // Records that the node produced an element.
-  void add_element() LOCKS_EXCLUDED(mu_) {
-    mutex_lock l(mu_);
-    num_elements_++;
-  }
-
-  // Adds an input.
-  void add_input(std::shared_ptr<Node> node) LOCKS_EXCLUDED(mu_) {
-    mutex_lock l(mu_);
-    inputs_.push_back(node);
-  }
-
-  // Increments the aggregate processing time by the given delta.
-  void add_processing_time(int64 delta) LOCKS_EXCLUDED(mu_) {
-    mutex_lock l(mu_);
-    processing_time_ += delta;
-  }
-
-  // Returns the unique node ID.
-  int64 id() LOCKS_EXCLUDED(mu_) { return id_; }
-
-  // Returns the node inputs.
-  std::list<std::shared_ptr<Node>> inputs() LOCKS_EXCLUDED(mu_) {
-    tf_shared_lock l(mu_);
-    return inputs_;
-  }
-
-  // Returns the node name.
-  const string& name() LOCKS_EXCLUDED(mu_) {
-    tf_shared_lock l(mu_);
-    return name_;
-  }
-
-  // Returns the number of elements produced by the node.
-  int64 num_elements() LOCKS_EXCLUDED(mu_) {
-    tf_shared_lock l(mu_);
-    return num_elements_;
-  }
-
-  // Returns the node output.
-  std::shared_ptr<Node> output() LOCKS_EXCLUDED(mu_) {
-    tf_shared_lock l(mu_);
-    return output_;
-  }
-
-  // Removes an input.
-  void remove_input(std::shared_ptr<Node> input) LOCKS_EXCLUDED(mu_) {
-    mutex_lock l(mu_);
-    inputs_.remove(input);
-  }
-
-  // Adds the given key-value pair to the node metadata.
-  void set_metadata(const string& key, int64 value) LOCKS_EXCLUDED(mu_) {
-    mutex_lock l(mu_);
-    metadata_[key] = value;
-  }
-
-  // Sets the node name.
-  void set_name(const string& name) LOCKS_EXCLUDED(mu_) {
-    mutex_lock l(mu_);
-    name_ = name;
-    type_ = TypeFromName(name);
-  }
-
-  // Set the node output.
-  void set_output(std::shared_ptr<Node> output) LOCKS_EXCLUDED(mu_) {
-    mutex_lock l(mu_);
-    output_ = output;
-  }
-
-  // Records that a node thread has started work.
-  void start_work() LOCKS_EXCLUDED(mu_) {
-    mutex_lock l(mu_);
-    work_start_[std::this_thread::get_id()] = Env::Default()->NowNanos();
-  }
-
-  // Records that a node thread has stopped work.
-  void stop_work() LOCKS_EXCLUDED(mu_) {
-    mutex_lock l(mu_);
-    auto iter = work_start_.find(std::this_thread::get_id());
-    CHECK(work_start_.end() != iter)
-        << "Encountered a stop event that was not preceded by a start event.";
-    processing_time_ += Env::Default()->NowNanos() - iter->second;
-    work_start_.erase(iter);
-  }
-
- private:
-  // Represents a performance knob.
-  struct Knob {
-    Node* node;
-    int64 processing_time;
-    int64 value;
-  };
-
-  enum class Type {
-    BATCH = 0,
-    CACHE,
-    CONCATENATE,
-    FILTER,
-    FLAT_MAP,
-    INTERLEAVE,
-    MAP,
-    MAP_AND_BATCH,
-    PADDED_BATCH,
-    PARALLEL_INTERLEAVE,
-    PARALLEL_INTERLEAVE_V2,
-    PARALLEL_MAP,
-    PREFETCH,
-    REPEAT,
-    SHUFFLE,
-    SKIP,
-    TAKE,
-    ZIP,
-    UNKNOWN,
-  };
-
-  // Collects performance knobs in the subtree rooted in this node.
-  void CollectKnobs(std::vector<Node::Knob>* knobs) LOCKS_EXCLUDED(mu_);
-
-  // Returns the per-element processing time spent in this node.
-  int64 NanosPerElement() LOCKS_EXCLUDED(mu_) {
-    mutex_lock l(mu_);
-    return NanosPerElementLocked();
-  }
-
-  int64 NanosPerElementLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-    if (num_elements_ == 0) {
-      return 0;
-    }
-    return (int64)((double)processing_time_ / (double)num_elements_);
-  }
-
-  // Returns the per-element output time for this node.
-  int64 OutputTime(std::vector<int64>* input_times) LOCKS_EXCLUDED(mu_) {
-    mutex_lock l(mu_);
-    return OutputTimeLocked(input_times);
-  }
-
-  int64 OutputTimeLocked(std::vector<int64>* input_times)
-      EXCLUSIVE_LOCKS_REQUIRED(mu_);
-
-  int64 OutputTimeForInputs(std::vector<int64>* input_times)
-      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-    int64 sum = 0;
-    for (auto input : inputs_) {
-      sum += input->OutputTime(input_times);
-    }
-    return sum;
-  }
-
-  // Returns the per-element processing time spent in the subtree rooted in this
-  // node.
-  int64 ProcessingTime() LOCKS_EXCLUDED(mu_) {
-    mutex_lock l(mu_);
-    return ProcessingTimeLocked();
-  }
-
-  int64 ProcessingTimeLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_);
-
-  // Returns the per-element processing time spent in the inputs of this node.
-  int64 ProcessingTimeForInputs() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-    int64 sum = 0;
-    for (auto input : inputs_) {
-      sum += input->ProcessingTimeLocked();
-    }
-    return sum;
-  }
-
-  // Serializes the node state into the given proto.
-  void ToProto(proto::Node* node_proto) LOCKS_EXCLUDED(mu_) {
-    mutex_lock l(mu_);
-    node_proto->set_id(id_);
-    node_proto->set_name(name_);
-    node_proto->set_num_elements(num_elements_);
-    node_proto->set_processing_time(processing_time_);
-    for (const std::shared_ptr<Node>& input : inputs_) {
-      node_proto->add_input(input->id());
-    }
-    if (output_) {
-      node_proto->set_output(output_->id());
-    }
-    node_proto->mutable_metadata()->insert(metadata_.begin(), metadata_.end());
-  }
-
-  Type TypeFromName(const string& name) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-    if (name_ == "Batch") {
-      return Type::BATCH;
-    }
-    if (str_util::EndsWith(name_, "Cache")) {
-      return Type::CACHE;
-    }
-    if (name_ == "Concatenate") {
-      return Type::CONCATENATE;
-    }
-    if (name_ == "Filter") {
-      return Type::FILTER;
-    }
-    if (name_ == "FlatMap") {
-      return Type::FLAT_MAP;
-    }
-    if (name_ == "Interleave") {
-      return Type::INTERLEAVE;
-    }
-    if (name_ == "Map") {
-      return Type::MAP;
-    }
-    if (name_ == "MapAndBatch") {
-      return Type::MAP_AND_BATCH;
-    }
-    if (name_ == "PaddedBatch") {
-      return Type::PADDED_BATCH;
-    }
-    if (name_ == "ParallelInterleave") {
-      return Type::PARALLEL_INTERLEAVE;
-    }
-    if (name_ == "ParallelInterleaveV2") {
-      return Type::PARALLEL_INTERLEAVE_V2;
-    }
-    if (name_ == "ParallelMap") {
-      return Type::PARALLEL_MAP;
-    }
-    if (name_ == "Prefetch") {
-      return Type::PREFETCH;
-    }
-    if (str_util::EndsWith(name_, "Repeat")) {
-      return Type::REPEAT;
-    }
-    if (name_ == "Shuffle") {
-      return Type::SHUFFLE;
-    }
-    if (str_util::EndsWith(name_, "Skip")) {
-      return Type::SKIP;
-    }
-    if (str_util::EndsWith(name_, "Take")) {
-      return Type::TAKE;
-    }
-    if (name_ == "Zip") {
-      return Type::ZIP;
-    }
-    return Type::UNKNOWN;
-  }
-
-  mutex mu_;
-  const int64 id_;
-  Type type_ GUARDED_BY(mu_);
-  string name_ GUARDED_BY(mu_);
-  int64 processing_time_ GUARDED_BY(mu_) = 0;
-  int64 num_elements_ GUARDED_BY(mu_) = 0;
-  std::map<std::thread::id, int64> work_start_ GUARDED_BY(mu_);
-  std::map<string, int64> metadata_ GUARDED_BY(mu_);
-  std::list<std::shared_ptr<Node>> inputs_ GUARDED_BY(mu_);
-  std::shared_ptr<Node> output_ GUARDED_BY(mu_);
-
-  friend class Model;
-};
-
 // Abstract representation of a TensorFlow input pipeline that can be used
 // for collecting runtime information and optimizing performance. It collects
 // runtime information about execution of the input pipeline that is used to
 // create a performance model, which is in turn used to identify optimal values
-// of performance knobs.
+// of tunable parameters.
 //
 // Developers of tf.data transformations are not expected to interact with this
 // class directly. Boiler plate code for creating the abstract representation of
 // the input pipeline and collecting runtime information has been added to the
 // implementation of `DatasetBase` and `DatasetBaseIterator` respectively.
-//
-// TODO(jsimsa): Add a mechanism for feeding the result of the optimization
-// into the input pipeline.
 class Model {
  public:
   Model() = default;
-  explicit Model(const proto::Model& model_proto);
 
-  ~Model() {}
-
-  // Returns the model output node.
-  std::shared_ptr<Node> output() LOCKS_EXCLUDED(mu_) {
-    tf_shared_lock l(mu_);
-    return output_;
-  }
-
-  // Adds a node with the given name and given output (identified by name).
-  std::shared_ptr<Node> AddNode(const string& name, const string& output_name)
+  // Adds a constant parameter for the given node.
+  void AddConstantParameter(const string& node_name,
+                            const string& parameter_name, int64 value)
       LOCKS_EXCLUDED(mu_);
 
-  // Looks up the node using the given name.
-  std::shared_ptr<Node> LookupNode(const string& name) LOCKS_EXCLUDED(mu_);
+  // Adds a node with the given name and given output (identified by name).
+  void AddNode(const string& name, const string& output_name)
+      LOCKS_EXCLUDED(mu_);
+
+  // Increments the processing time for the given node.
+  void AddProcessingTime(const string& name, int64 delta) LOCKS_EXCLUDED(mu_);
+
+  // Adds a tunable parameter for the given node.
+  void AddTunableParameter(const string& node_name,
+                           const string& parameter_name,
+                           std::atomic<int64>* value, int64 min, int64 max,
+                           condition_variable* cond_var) LOCKS_EXCLUDED(mu_);
 
   // Runs optimization.
-  void Optimize() LOCKS_EXCLUDED(mu_);
+  void Optimize(int64 cpu_budget) LOCKS_EXCLUDED(mu_);
 
-  // Outputs the state of a model to a file.
-  //
-  // TODO(jsimsa): Remove this method once the optimization loop is closed.
-  void OutputToFile() LOCKS_EXCLUDED(mu_);
+  // Records that a node has produced an element.
+  void RecordElement(const string& name) LOCKS_EXCLUDED(mu_);
 
-  // Removes the node identified by the given name.
-  void RemoveNode(const string& prefix) LOCKS_EXCLUDED(mu_);
+  // Records that the given node has started work. If `stop_output` is set, it
+  // also records that the output of the given node has stopped work.
+  void RecordStart(const string& name, bool stop_output) LOCKS_EXCLUDED(mu_);
 
-  // Serializes the model state to the given proto.
-  void ToProto(proto::Model* model_proto) LOCKS_EXCLUDED(mu_);
+  // Records that the given node has stopped work. If `start_output` is set, it
+  // also records that the output of the given node has started work.
+  void RecordStop(const string& name, bool start_output) LOCKS_EXCLUDED(mu_);
+
+  // Removes the given node.
+  void RemoveNode(const string& name) LOCKS_EXCLUDED(mu_);
 
  private:
-  static void AddNodeToProto(const std::shared_ptr<Node>& node,
-                             proto::Model* model_proto);
+  // Abstract representation of a TensorFlow input pipeline node. It collects
+  // information about inputs to this node, processing time spent executing the
+  // node logic, the number of elements produced by the node, and various
+  // other information (e.g. batch size or execution parallelism).
+  //
+  // Developers of tf.data transformations are not expected to interact with
+  // this class directly. Boilerplate code for creating the abstract
+  // representation of the input pipeline and collecting common information has
+  // been added to the implementation of `DatasetBase` and `DatasetBaseIterator`
+  // respectively.
+  //
+  // In addition, `DatasetBaseIterator` provides wrappers that can be used for
+  // transformation-specific information collection. The `SetMetadata` wrapper
+  // can be used to pass arbitrary metadata to the modeling framework, while the
+  // `StartWork` and `StopWork` wrappers should be used to correctly account for
+  // processing time of multi-threaded transformations that yield the CPU; such
+  // transformations should invoke `StartWork()` when a transformation thread
+  // starts executing (e.g. when created or woken up) and `StopWork()` when a
+  // transformation thread stops executing (e.g. when returning or waiting).
+  //
+  // TODO(jsimsa): Create an API to capture the abstract semantics of each
+  // tf.data transformation and replace switch-case blocks with inheritance.
+  class Node {
+   public:
+    // Represents a tunable parameter.
+    struct Tunable {
+      Tunable(std::atomic<int64>* value, int64 min, int64 max,
+              condition_variable* cond_var)
+          : value(*value),
+            min(min),
+            max(max),
+            value_ptr(value),
+            cond_var(cond_var) {}
 
-  std::vector<Node::Knob> CollectKnobs() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+      // Identifies the model value of the parameter. This can be different from
+      // the actual value (e.g. during optimization search).
+      int64 value;
 
-  int64 OutputTime() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+      // Identifies the minimum value of the parameter.
+      int64 min;
 
-  int64 ProcessingTime() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+      // Identifies the maximum value of the parameter.
+      int64 max;
 
+      // Points to the actual value of the parameter. Not owned.
+      std::atomic<int64>* value_ptr;
+
+      // If non-null, this condition variable is notified when the model updates
+      // the actual value of the parameter (via `value_ptr`). Not owned.
+      condition_variable* cond_var;
+    };
+
+    Node(int64 id, const string& name, std::shared_ptr<Node> output)
+        : id_(id), name_(name), type_(TypeFromName(name)), output_(output) {}
+
+    // Adds a constant parameter.
+    void add_constant_param(const string& name, int64 value)
+        LOCKS_EXCLUDED(mu_) {
+      mutex_lock l(mu_);
+      constant_params_[name] = value;
+    }
+
+    // Adds an input.
+    void add_input(std::shared_ptr<Node> node) LOCKS_EXCLUDED(mu_) {
+      mutex_lock l(mu_);
+      inputs_.push_back(node);
+    }
+
+    // Increments the aggregate processing time by the given delta.
+    void add_processing_time(int64 delta) LOCKS_EXCLUDED(mu_) {
+      mutex_lock l(mu_);
+      processing_time_ += delta;
+    }
+
+    // Adds a tunable parameter.
+    void add_tunable_param(const string& name, std::atomic<int64>* value,
+                           int64 min, int64 max, condition_variable* cond_var)
+        LOCKS_EXCLUDED(mu_) {
+      mutex_lock l(mu_);
+      tunable_params_[name] =
+          std::make_shared<Tunable>(value, min, max, cond_var);
+    }
+
+    // Returns the unique node ID.
+    int64 id() LOCKS_EXCLUDED(mu_) { return id_; }
+
+    // Returns the node inputs.
+    std::list<std::shared_ptr<Node>> inputs() LOCKS_EXCLUDED(mu_) {
+      tf_shared_lock l(mu_);
+      return inputs_;
+    }
+
+    // Returns the node name.
+    const string& name() LOCKS_EXCLUDED(mu_) {
+      tf_shared_lock l(mu_);
+      return name_;
+    }
+
+    // Returns the number of elements produced by the node.
+    int64 num_elements() LOCKS_EXCLUDED(mu_) {
+      tf_shared_lock l(mu_);
+      return num_elements_;
+    }
+
+    // Returns the node output.
+    std::shared_ptr<Node> output() LOCKS_EXCLUDED(mu_) {
+      tf_shared_lock l(mu_);
+      return output_;
+    }
+
+    // Records that the node produced an element.
+    void record_element() LOCKS_EXCLUDED(mu_) {
+      mutex_lock l(mu_);
+      num_elements_++;
+    }
+
+    // Records that a node thread has started executing.
+    void record_start() LOCKS_EXCLUDED(mu_) {
+      mutex_lock l(mu_);
+      work_start_[std::this_thread::get_id()] = Env::Default()->NowNanos();
+    }
+
+    // Records that a node thread has stopped executing.
+    void record_stop() LOCKS_EXCLUDED(mu_) {
+      mutex_lock l(mu_);
+      std::thread::id tid = std::this_thread::get_id();
+      auto start_time = gtl::FindOrNull(work_start_, tid);
+      DCHECK(start_time)
+          << "Encountered a stop event that was not preceded by a start event.";
+      if (start_time) {
+        processing_time_ += Env::Default()->NowNanos() - *start_time;
+        work_start_.erase(tid);
+      }
+    }
+
+    // Removes an input.
+    void remove_input(std::shared_ptr<Node> input) LOCKS_EXCLUDED(mu_) {
+      mutex_lock l(mu_);
+      inputs_.remove(input);
+    }
+
+    // Sets the node output.
+    void set_output(std::shared_ptr<Node> output) LOCKS_EXCLUDED(mu_) {
+      mutex_lock l(mu_);
+      output_ = output;
+    }
+
+    // Collects tunable parameters in the subtree rooted in this node.
+    void CollectTunables(std::vector<std::shared_ptr<Tunable>>* tunables)
+        LOCKS_EXCLUDED(mu_);
+
+    // Returns the per-element output time for this node.
+    int64 OutputTime(std::vector<int64>* input_times) LOCKS_EXCLUDED(mu_) {
+      tf_shared_lock l(mu_);
+      return OutputTimeLocked(input_times);
+    }
+
+    // Returns the per-element processing time spent in the subtree rooted in
+    // this node.
+    int64 ProcessingTime() LOCKS_EXCLUDED(mu_) {
+      tf_shared_lock l(mu_);
+      return ProcessingTimeLocked();
+    }
+
+   private:
+    enum class Type {
+      BATCH = 0,
+      CACHE,
+      CONCATENATE,
+      FILTER,
+      FLAT_MAP,
+      INTERLEAVE,
+      MAP,
+      MAP_AND_BATCH,
+      PADDED_BATCH,
+      PARALLEL_INTERLEAVE,
+      PARALLEL_INTERLEAVE_V2,
+      PARALLEL_MAP,
+      PREFETCH,
+      REPEAT,
+      SHUFFLE,
+      SKIP,
+      TAKE,
+      ZIP,
+      UNKNOWN,
+    };
+
+    // Gets the value of the given parameter (tunable or constant).
+    int64 GetParameterValue(const string& name) SHARED_LOCKS_REQUIRED(mu_);
+
+    // Returns the per-element processing time spent in this node.
+    int64 NanosPerElement() LOCKS_EXCLUDED(mu_) {
+      tf_shared_lock l(mu_);
+      return NanosPerElementLocked();
+    }
+
+    int64 NanosPerElementLocked() SHARED_LOCKS_REQUIRED(mu_) {
+      if (num_elements_ == 0) {
+        return 0;
+      }
+      return (int64)((double)processing_time_ / (double)num_elements_);
+    }
+
+    int64 OutputTimeLocked(std::vector<int64>* input_times)
+        SHARED_LOCKS_REQUIRED(mu_);
+
+    int64 OutputTimeForInputs(std::vector<int64>* input_times)
+        SHARED_LOCKS_REQUIRED(mu_) {
+      int64 sum = 0;
+      for (auto input : inputs_) {
+        sum += input->OutputTime(input_times);
+      }
+      return sum;
+    }
+
+    int64 ProcessingTimeLocked() SHARED_LOCKS_REQUIRED(mu_);
+
+    // Returns the per-element processing time spent in the inputs of this node.
+    int64 ProcessingTimeForInputs() SHARED_LOCKS_REQUIRED(mu_) {
+      int64 sum = 0;
+      for (auto input : inputs_) {
+        sum += input->ProcessingTime();
+      }
+      return sum;
+    }
+
+    Type TypeFromName(const string& name) SHARED_LOCKS_REQUIRED(mu_) {
+      if (name_ == "Batch") {
+        return Type::BATCH;
+      }
+      if (str_util::EndsWith(name_, "Cache")) {
+        return Type::CACHE;
+      }
+      if (name_ == "Concatenate") {
+        return Type::CONCATENATE;
+      }
+      if (name_ == "Filter") {
+        return Type::FILTER;
+      }
+      if (name_ == "FlatMap") {
+        return Type::FLAT_MAP;
+      }
+      if (name_ == "Interleave") {
+        return Type::INTERLEAVE;
+      }
+      if (name_ == "Map") {
+        return Type::MAP;
+      }
+      if (name_ == "MapAndBatch") {
+        return Type::MAP_AND_BATCH;
+      }
+      if (name_ == "PaddedBatch") {
+        return Type::PADDED_BATCH;
+      }
+      if (name_ == "ParallelInterleave") {
+        return Type::PARALLEL_INTERLEAVE;
+      }
+      if (name_ == "ParallelInterleaveV2") {
+        return Type::PARALLEL_INTERLEAVE_V2;
+      }
+      if (name_ == "ParallelMap") {
+        return Type::PARALLEL_MAP;
+      }
+      if (name_ == "Prefetch") {
+        return Type::PREFETCH;
+      }
+      if (str_util::EndsWith(name_, "Repeat")) {
+        return Type::REPEAT;
+      }
+      if (name_ == "Shuffle") {
+        return Type::SHUFFLE;
+      }
+      if (str_util::EndsWith(name_, "Skip")) {
+        return Type::SKIP;
+      }
+      if (str_util::EndsWith(name_, "Take")) {
+        return Type::TAKE;
+      }
+      if (name_ == "Zip") {
+        return Type::ZIP;
+      }
+      return Type::UNKNOWN;
+    }
+
+    mutex mu_;
+    const int64 id_;
+    const string name_;
+    const Type type_;
+    int64 processing_time_ GUARDED_BY(mu_) = 0;
+    int64 num_elements_ GUARDED_BY(mu_) = 0;
+    std::map<std::thread::id, int64> work_start_ GUARDED_BY(mu_);
+    std::map<string, int64> constant_params_ GUARDED_BY(mu_);
+    // Tunables are shared with the model during optimization.
+    std::map<string, std::shared_ptr<Tunable>> tunable_params_ GUARDED_BY(mu_);
+    std::list<std::shared_ptr<Node>> inputs_ GUARDED_BY(mu_);
+    std::shared_ptr<Node> output_ GUARDED_BY(mu_);
+  };
+
+  std::vector<std::shared_ptr<Node::Tunable>> CollectTunables()
+      SHARED_LOCKS_REQUIRED(mu_);
+
+  int64 OutputTime() SHARED_LOCKS_REQUIRED(mu_);
+
+  int64 ProcessingTime() SHARED_LOCKS_REQUIRED(mu_);
+
+  // Used for coordination between different input pipeline threads. Exclusive
+  // access is required only when adding or removing nodes. Concurrent access to
+  // existing nodes is protected by a node mutex.
   mutex mu_;
   int64 id_counter_ GUARDED_BY(mu_) = 1;
   std::shared_ptr<Node> output_ GUARDED_BY(mu_);
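The new interface above is meant to be driven by the tf.data runtime rather than by end users, but the call pattern is easier to see in one place. The sketch below is illustrative only and not part of the patch: it assumes the `tensorflow::data::model` namespace used by this header, and the function and node names ("PrefetchLoop", "Prefetch", "Map") are hypothetical stand-ins for what `DatasetBase` / `DatasetBaseIterator` wire up internally.

    #include <atomic>
    #include "tensorflow/core/framework/model.h"
    #include "tensorflow/core/platform/mutex.h"

    namespace tensorflow {
    namespace data {

    // Sketch only: one background thread of a hypothetical "Prefetch" node
    // reporting runtime information to the model declared above.
    void PrefetchLoop(model::Model* model, std::atomic<int64>* buffer_size,
                      condition_variable* cond_var) {
      // Register the node (its output is the downstream "Map" node) and expose
      // `buffer_size` as a tunable parameter the model may adjust.
      model->AddNode("Prefetch", /*output_name=*/"Map");
      model->AddTunableParameter("Prefetch", "buffer_size", buffer_size,
                                 /*min=*/1, /*max=*/16, cond_var);

      // Bracket CPU work with start/stop events so processing time is counted
      // only while this thread is actually executing.
      model->RecordStart("Prefetch", /*stop_output=*/false);
      // ... produce one element here ...
      model->RecordStop("Prefetch", /*start_output=*/false);
      model->RecordElement("Prefetch");

      // Periodically ask the model to re-optimize all tunables against a CPU
      // budget (e.g. the number of schedulable CPUs).
      model->Optimize(/*cpu_budget=*/8);
    }

    }  // namespace data
    }  // namespace tensorflow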
diff --git a/tensorflow/core/framework/model.proto b/tensorflow/core/framework/model.proto
deleted file mode 100644
index 2600000..0000000
--- a/tensorflow/core/framework/model.proto
+++ /dev/null
@@ -1,30 +0,0 @@
-syntax = "proto3";
-
-package tensorflow.data.model.proto;
-option cc_enable_arenas = true;
-
-message Model {
-  // Counter used for generating new node IDs.
-  int64 id_counter = 1;
-  // Nodes of this model.
-  repeated Node node = 2;
-  // The ID of the output node.
-  int64 output = 3;
-};
-
-message Node {
-  // The node ID.
-  int64 id = 1;
-  // The node name.
-  string name = 2;
-  // Input node IDs.
-  repeated int64 input = 3;
-  // Output node ID.
-  int64 output = 4;
-  // Number of elements produced by the node.
-  int64 num_elements = 5;
-  // The CPU time spent by running threads of this node.
-  int64 processing_time = 6;
-  // Key-value store for node metadata (e.g. batch size or parallelism).
-  map<string, int32> metadata = 7;
-};
diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc
index bacc1d7..43ac1d0 100644
--- a/tensorflow/core/framework/node_def_util.cc
+++ b/tensorflow/core/framework/node_def_util.cc
@@ -372,6 +372,14 @@
                                  node_def.name());
 }
 
+Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
+                         DataTypeVector* inputs) {
+  for (const auto& arg : op_def.input_arg()) {
+    TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs));
+  }
+  return Status::OK();
+}
+
 Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
                          int output_port, DataType* output_type) {
   DataTypeVector output_types;
@@ -397,12 +405,18 @@
 
 Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
                          DataTypeVector* inputs, DataTypeVector* outputs) {
-  for (const auto& arg : op_def.input_arg()) {
-    TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs));
-  }
+  TF_RETURN_IF_ERROR(InputTypesForNode(node_def, op_def, inputs));
   return OutputTypesForNode(node_def, op_def, outputs);
 }
 
+Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def,
+                         int* num_outputs) {
+  DataTypeVector outputs;
+  TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, op_def, &outputs));
+  *num_outputs = outputs.size();
+  return Status::OK();
+}
+
 Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
   if (node_def.op() != op_def.name()) {
     return errors::InvalidArgument("NodeDef op '", node_def.op(),
diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h
index 499034c..187bfa2 100644
--- a/tensorflow/core/framework/node_def_util.h
+++ b/tensorflow/core/framework/node_def_util.h
@@ -249,6 +249,10 @@
 // REQUIRES: ValidateOpDef(op_def).ok()
 Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
                         int input_port, DataType* input_type);
+// Computes the input types for a specific node.
+// REQUIRES: ValidateOpDef(op_def).ok()
+Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
+                         DataTypeVector* inputs);
 // Computes the output type for a specific node output.
 // REQUIRES: ValidateOpDef(op_def).ok()
 Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
@@ -261,6 +265,10 @@
 // REQUIRES: ValidateOpDef(op_def).ok()
 Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
                          DataTypeVector* inputs, DataTypeVector* outputs);
+// Computes the number of outputs for a specific node.
+// REQUIRES: ValidateOpDef(op_def).ok()
+Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def,
+                         int* num_outputs);
 
 // Validates that the NodeDef:
 // * Defines all expected attrs from the OpDef.
diff --git a/tensorflow/core/framework/node_def_util_test.cc b/tensorflow/core/framework/node_def_util_test.cc
index 74cc594..d9d4370 100644
--- a/tensorflow/core/framework/node_def_util_test.cc
+++ b/tensorflow/core/framework/node_def_util_test.cc
@@ -370,6 +370,48 @@
                       "Illegal op input name 'a:00");
 }
 
+TEST(InputTypesForNode, Simple) {
+  const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
+                                   .Input("a: float")
+                                   .Input("b: int32")
+                                   .Output("c: string")
+                                   .Output("d: bool"));
+  const NodeDef node_def = ToNodeDef(
+      NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput()));
+  DataTypeVector types;
+  EXPECT_TRUE(InputTypesForNode(node_def, op_def, &types).ok());
+  EXPECT_EQ(types[0], DT_FLOAT);
+  EXPECT_EQ(types[1], DT_INT32);
+
+  DataType type;
+  EXPECT_TRUE(InputTypeForNode(node_def, op_def, 0, &type).ok());
+  EXPECT_EQ(type, DT_FLOAT);
+  EXPECT_TRUE(InputTypeForNode(node_def, op_def, 1, &type).ok());
+  EXPECT_EQ(type, DT_INT32);
+  EXPECT_FALSE(InputTypeForNode(node_def, op_def, 2, &type).ok());
+}
+
+TEST(OutputTypesForNode, Simple) {
+  const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
+                                   .Input("a: float")
+                                   .Input("b: int32")
+                                   .Output("c: string")
+                                   .Output("d: bool"));
+  const NodeDef node_def = ToNodeDef(
+      NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput()));
+  DataTypeVector types;
+  EXPECT_TRUE(OutputTypesForNode(node_def, op_def, &types).ok());
+  EXPECT_EQ(types[0], DT_STRING);
+  EXPECT_EQ(types[1], DT_BOOL);
+
+  DataType type;
+  EXPECT_TRUE(OutputTypeForNode(node_def, op_def, 0, &type).ok());
+  EXPECT_EQ(type, DT_STRING);
+  EXPECT_TRUE(OutputTypeForNode(node_def, op_def, 1, &type).ok());
+  EXPECT_EQ(type, DT_BOOL);
+  EXPECT_FALSE(OutputTypeForNode(node_def, op_def, 2, &type).ok());
+}
+
 TEST(NameRangesForNodeTest, Simple) {
   const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
                                    .Input("a: float")
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index c694e10..3e34bf0 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -41,6 +41,7 @@
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/ptr_util.h"
 
 namespace tensorflow {
 
@@ -80,10 +81,8 @@
 
 // OpKernel ------------------------------------------------------------------
 
-// TODO(mrry): Convert to std::make_unique when available.
 OpKernel::OpKernel(OpKernelConstruction* context)
-    : OpKernel(context,
-               std::unique_ptr<const NodeDef>(new NodeDef(context->def()))) {}
+    : OpKernel(context, MakeUnique<const NodeDef>(context->def())) {}
 
 OpKernel::OpKernel(OpKernelConstruction* context,
                    std::unique_ptr<const NodeDef> node_def)
@@ -266,9 +265,12 @@
   params_->ensure_eigen_gpu_device();
   if (params_->eigen_gpu_device != nullptr) {
     Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes());
-    params_->device->ReinitializeGpuDevice(this, params_->eigen_gpu_device,
-                                           params_->op_device_context,
-                                           eigen_gpu_allocator);
+    Status s = params_->device->ReinitializeGpuDevice(
+        this, params_->eigen_gpu_device, params_->op_device_context,
+        eigen_gpu_allocator);
+    if (!s.ok()) {
+      SetStatus(s);
+    }
   }
   if (params_->record_tensor_accesses) {
     referenced_tensors_.Init();
@@ -525,10 +527,8 @@
       return nullptr;
     }
   }
-  // TODO(rmlarsen): Use MakeUnique here. There is already a copy in
-  // tensorflow/compiler/xla/ptr_util.h. Perhaps this should be part of
-  // general cleanup of ownership in this code.
-  std::unique_ptr<Tensor> output_tensor(new Tensor());
+
+  auto output_tensor = MakeUnique<Tensor>();
   CHECK(output_tensor->CopyFrom(*input.tensor, output_shape));
   return output_tensor;
 }
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index e752599..4bbd6c3 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -372,18 +372,37 @@
 template <typename ListType, typename ElementType>
 class OpArgIterator {
  public:
-  typedef OpArgIterator<ListType, ElementType> ME;
+  using iterator_category = std::forward_iterator_tag;
+  using value_type = ElementType;
+  using pointer = ElementType*;
+  using reference = ElementType&;
+  using difference_type = ptrdiff_t;
+
   OpArgIterator(const ListType* list, int i) : list_(list), i_(i) {}
-  bool operator==(const ME& rhs) {
+
+  bool operator==(const OpArgIterator& rhs) {
     DCHECK(list_ == rhs.list_);
     return i_ == rhs.i_;
   }
-  bool operator!=(const ME& rhs) {
+
+  bool operator!=(const OpArgIterator& rhs) {
     DCHECK(list_ == rhs.list_);
     return i_ != rhs.i_;
   }
-  void operator++() { ++i_; }
-  ElementType& operator*() { return (*list_)[i_]; }
+
+  OpArgIterator operator++() {  // prefix ++it
+    ++i_;
+    return *this;
+  }
+
+  OpArgIterator operator++(int) {  // postfix it++
+    OpArgIterator old_value = *this;
+    ++i_;
+    return old_value;
+  }
+
+  reference operator*() { return (*list_)[i_]; }
+  pointer operator->() { return &(*list_)[i_]; }
 
  private:
   const ListType* const list_;
@@ -394,7 +413,7 @@
 // that are passed to the op as a single named argument.
 class OpInputList {
  public:
-  typedef OpArgIterator<OpInputList, const Tensor&> Iterator;
+  typedef OpArgIterator<OpInputList, const Tensor> Iterator;
   OpInputList() : ctx_(nullptr), start_(0), stop_(0) {}
   OpInputList(OpKernelContext* ctx, int start, int stop)
       : ctx_(ctx), start_(start), stop_(stop) {}
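The typedefs added to OpArgIterator make it a conforming forward iterator, so standard algorithms that rely on std::iterator_traits now work over OpInputList / OpOutputList ranges. A small sketch, assuming OpInputList exposes begin()/end() returning this Iterator as it does elsewhere in this header:

    #include <algorithm>
    #include "tensorflow/core/framework/op_kernel.h"

    namespace tensorflow {

    // Sketch only: counts scalar tensors in a named input list with a standard
    // algorithm, which requires the iterator traits introduced above.
    int64 CountScalarInputs(const OpInputList& inputs) {
      return std::count_if(inputs.begin(), inputs.end(),
                           [](const Tensor& t) { return t.dims() == 0; });
    }

    }  // namespace tensorflow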
diff --git a/tensorflow/core/framework/op_segment.cc b/tensorflow/core/framework/op_segment.cc
index dfc5aa7..75ed4a4 100644
--- a/tensorflow/core/framework/op_segment.cc
+++ b/tensorflow/core/framework/op_segment.cc
@@ -15,6 +15,7 @@
 
 #include "tensorflow/core/framework/op_segment.h"
 
+#include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
@@ -99,4 +100,11 @@
   delete item;
 }
 
+bool OpSegment::ShouldOwnKernel(FunctionLibraryRuntime* lib,
+                                const string& node_op) {
+  // OpSegment should not own kernel if the node is stateless, or a function.
+  return lib->IsStateful(node_op) &&
+         lib->GetFunctionLibraryDefinition()->Find(node_op) == nullptr;
+}
+
 }  // end namespace tensorflow
diff --git a/tensorflow/core/framework/op_segment.h b/tensorflow/core/framework/op_segment.h
index 4433a25..37d939e 100644
--- a/tensorflow/core/framework/op_segment.h
+++ b/tensorflow/core/framework/op_segment.h
@@ -60,6 +60,10 @@
   Status FindOrCreate(const string& session_handle, const string& node_name,
                       OpKernel** kernel, CreateKernelFn create_fn);
 
+  // Returns true if OpSegment should own the kernel.
+  static bool ShouldOwnKernel(FunctionLibraryRuntime* lib,
+                              const string& node_op);
+
  private:
   // op name -> OpKernel
   typedef std::unordered_map<string, OpKernel*> KernelMap;
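ShouldOwnKernel centralizes the ownership rule: stateful, non-function kernels are cached in the OpSegment and shared across a session's steps, while everything else is created and owned by the caller. A rough sketch of the intended caller pattern follows; the wrapper name and the `caller_owns` flag are hypothetical, and it assumes `FunctionLibraryRuntime::CreateKernel(const NodeDef&, OpKernel**)`.

    #include "tensorflow/core/framework/function.h"
    #include "tensorflow/core/framework/op_kernel.h"
    #include "tensorflow/core/framework/op_segment.h"

    namespace tensorflow {

    // Sketch only: decide between caller-owned and OpSegment-owned kernels.
    Status GetOrCreateKernel(FunctionLibraryRuntime* lib, OpSegment* opseg,
                             const string& session_handle, const NodeDef& ndef,
                             OpKernel** kernel, bool* caller_owns) {
      if (!OpSegment::ShouldOwnKernel(lib, ndef.op())) {
        *caller_owns = true;  // stateless op or function: caller deletes *kernel
        return lib->CreateKernel(ndef, kernel);
      }
      *caller_owns = false;  // cached in the OpSegment for the session lifetime
      auto create_fn = [lib, &ndef](OpKernel** out) {
        return lib->CreateKernel(ndef, out);
      };
      return opseg->FindOrCreate(session_handle, ndef.name(), kernel, create_fn);
    }

    }  // namespace tensorflow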
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc
index 516afa5..eb9c79f 100644
--- a/tensorflow/core/framework/tensor.cc
+++ b/tensorflow/core/framework/tensor.cc
@@ -948,9 +948,69 @@
   }
 }
 
+// Appends the spacing between elements for a given dim onto a result string
+void PrintDimSpacing(int dim_index, int num_dims, string* result) {
+  if (dim_index == num_dims - 1) {
+    strings::StrAppend(result, " ");
+    return;
+  }
+  for (int j = 0; j < num_dims - dim_index - 1; j++) {
+    strings::StrAppend(result, "\n");
+  }
+  for (int j = 0; j <= dim_index; j++) {
+    strings::StrAppend(result, " ");
+  }
+}
+
+// Print from left dim to right dim recursively.
+template <typename T>
+void PrintOneDimV2(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
+                   int64 num_elts_at_ends, int num_dims, const T* data,
+                   int64 data_index, string* result) {
+  // We have recursed beyond all the dimensions into a single element
+  // of the tensor.
+  if (dim_index == num_dims) {
+    strings::StrAppend(result, PrintOneElement(data[data_index]));
+    return;
+  }
+
+  strings::StrAppend(result, "[");
+  int64 element_count = shape[dim_index];
+  int64 start_of_end =
+      std::max(num_elts_at_ends, element_count - num_elts_at_ends);
+
+  // Loop every element of one dim.
+  int64 elements_per_iter = 1;
+  for (int i = dim_index + 1; i < num_dims; i++) {
+    elements_per_iter *= shape[i];
+  }
+  for (int64 i = 0; (i < num_elts_at_ends) && (i < element_count); i++) {
+    if (i > 0) {
+      PrintDimSpacing(dim_index, num_dims, result);
+    }
+
+    // Print the sub-tensor for this element.
+    PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data,
+                  data_index + elements_per_iter * i, result);
+  }
+  if (element_count > 2 * num_elts_at_ends) {
+    PrintDimSpacing(dim_index, num_dims, result);
+    strings::StrAppend(result, "...");
+  }
+  for (int64 i = start_of_end; i < element_count; i++) {
+    // Print the sub-tensor for this element.
+    PrintDimSpacing(dim_index, num_dims, result);
+    PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data,
+                  data_index + elements_per_iter * i, result);
+  }
+
+  strings::StrAppend(result, "]");
+}
+
 template <typename T>
 string SummarizeArray(int64 limit, int64 num_elts,
-                      const TensorShape& tensor_shape, const char* data) {
+                      const TensorShape& tensor_shape, const char* data,
+                      const bool print_v2) {
   string ret;
   const T* array = reinterpret_cast<const T*>(data);
 
@@ -963,17 +1023,26 @@
     if (num_elts > limit) strings::StrAppend(&ret, "...");
     return ret;
   }
-  int64 data_index = 0;
-  const int shape_size = tensor_shape.dims();
-  PrintOneDim(0, shape, limit, shape_size, array, &data_index, &ret);
+  if (print_v2) {
+    const int num_dims = tensor_shape.dims();
+    PrintOneDimV2(0, shape, limit, num_dims, array, 0, &ret);
+  } else {
+    int64 data_index = 0;
+    const int shape_size = tensor_shape.dims();
+    PrintOneDim(0, shape, limit, shape_size, array, &data_index, &ret);
 
-  if (num_elts > limit) strings::StrAppend(&ret, "...");
+    if (num_elts > limit) strings::StrAppend(&ret, "...");
+  }
+
   return ret;
 }
 }  // namespace
 
-string Tensor::SummarizeValue(int64 max_entries) const {
+string Tensor::SummarizeValue(int64 max_entries, bool print_v2) const {
   const int64 num_elts = NumElements();
+  if (max_entries < 0) {
+    max_entries = num_elts;
+  }
   size_t limit = std::min(max_entries, num_elts);
   if ((limit > 0) && (buf_ == nullptr)) {
     return strings::StrCat("uninitialized Tensor of ", num_elts,
@@ -982,50 +1051,54 @@
   const char* data = limit > 0 ? tensor_data().data() : nullptr;
   switch (dtype()) {
     case DT_HALF:
-      return SummarizeArray<Eigen::half>(limit, num_elts, shape_, data);
+      return SummarizeArray<Eigen::half>(limit, num_elts, shape_, data,
+                                         print_v2);
       break;
     case DT_FLOAT:
-      return SummarizeArray<float>(limit, num_elts, shape_, data);
+      return SummarizeArray<float>(limit, num_elts, shape_, data, print_v2);
       break;
     case DT_DOUBLE:
-      return SummarizeArray<double>(limit, num_elts, shape_, data);
+      return SummarizeArray<double>(limit, num_elts, shape_, data, print_v2);
       break;
     case DT_UINT32:
-      return SummarizeArray<uint32>(limit, num_elts, shape_, data);
+      return SummarizeArray<uint32>(limit, num_elts, shape_, data, print_v2);
       break;
     case DT_INT32:
-      return SummarizeArray<int32>(limit, num_elts, shape_, data);
+      return SummarizeArray<int32>(limit, num_elts, shape_, data, print_v2);
       break;
     case DT_UINT8:
     case DT_QUINT8:
-      return SummarizeArray<uint8>(limit, num_elts, shape_, data);
+      return SummarizeArray<uint8>(limit, num_elts, shape_, data, print_v2);
       break;
     case DT_UINT16:
     case DT_QUINT16:
-      return SummarizeArray<uint16>(limit, num_elts, shape_, data);
+      return SummarizeArray<uint16>(limit, num_elts, shape_, data, print_v2);
       break;
     case DT_INT16:
     case DT_QINT16:
-      return SummarizeArray<int16>(limit, num_elts, shape_, data);
+      return SummarizeArray<int16>(limit, num_elts, shape_, data, print_v2);
       break;
     case DT_INT8:
     case DT_QINT8:
-      return SummarizeArray<int8>(limit, num_elts, shape_, data);
+      return SummarizeArray<int8>(limit, num_elts, shape_, data, print_v2);
       break;
     case DT_UINT64:
-      return SummarizeArray<uint64>(limit, num_elts, shape_, data);
+      return SummarizeArray<uint64>(limit, num_elts, shape_, data, print_v2);
       break;
     case DT_INT64:
-      return SummarizeArray<int64>(limit, num_elts, shape_, data);
+      return SummarizeArray<int64>(limit, num_elts, shape_, data, print_v2);
       break;
     case DT_BOOL:
       // TODO(tucker): Is it better to emit "True False..."?  This
       // will emit "1 0..." which is more compact.
-      return SummarizeArray<bool>(limit, num_elts, shape_, data);
+      return SummarizeArray<bool>(limit, num_elts, shape_, data, print_v2);
       break;
     default: {
       // All irregular cases
       string ret;
+      if (print_v2) {
+        strings::StrAppend(&ret, "[");
+      }
       // TODO(irving): Don't call flat every time around this
       // loop.
       for (size_t i = 0; i < limit; ++i) {
@@ -1045,6 +1118,9 @@
         }
       }
       if (max_entries < num_elts) strings::StrAppend(&ret, "...");
+      if (print_v2) {
+        strings::StrAppend(&ret, "]");
+      }
       return ret;
     }
   }
diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h
index 696fd27..e412329 100644
--- a/tensorflow/core/framework/tensor.h
+++ b/tensorflow/core/framework/tensor.h
@@ -154,7 +154,7 @@
   /// Returns the estimated memory usage of this tensor.
   size_t TotalBytes() const;
 
-  // Returns the size of sallocated memory for this tensor.
+  // Returns the size of allocated memory for this tensor.
   size_t AllocatedBytes() const;
 
   /// Returns true iff this tensor is aligned.
@@ -430,7 +430,7 @@
       int64 begin) const;
 
   /// Render the first `max_entries` values in `*this` into a string.
-  string SummarizeValue(int64 max_entries) const;
+  string SummarizeValue(int64 max_entries, bool print_v2 = false) const;
 
   /// A human-readable summary of the tensor suitable for debugging.
   string DebugString() const;
diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc
index 9a78cdc..fc05c86 100644
--- a/tensorflow/core/framework/tensor_test.cc
+++ b/tensorflow/core/framework/tensor_test.cc
@@ -1295,6 +1295,63 @@
   EXPECT_EQ("one two three four five one...", x.SummarizeValue(6));
 }
 
+TEST(SummarizeValue, INT32_PRINT_V2) {
+  Tensor x = MkTensor<int>(DT_INT32, TensorShape({5}), {1, 2, 3, 4, 0});
+  EXPECT_EQ("[1 2 3 4 0]", x.SummarizeValue(16, true));
+  EXPECT_EQ("[1 2 3 4 0]", x.SummarizeValue(-1, true));
+  EXPECT_EQ("[1 2 ... 4 0]", x.SummarizeValue(2, true));
+  EXPECT_EQ("[1 ... 0]", x.SummarizeValue(1, true));
+  x = MkTensor<int>(DT_INT32, TensorShape({2, 2}), {1, 2, 3, 4, 0});
+  EXPECT_EQ("[[1 2]\n [3 4]]", x.SummarizeValue(16, true));
+  x = MkTensor<int>(DT_INT32, TensorShape({2, 2, 1, 1}), {1, 2, 3, 4, 0});
+  EXPECT_EQ("[[[[1]]\n\n  [[2]]]\n\n\n [[[3]]\n\n  [[4]]]]",
+            x.SummarizeValue(16, true));
+  x = MkTensor<int>(DT_INT32, TensorShape({0}), {});
+  EXPECT_EQ("[]", x.SummarizeValue(16, true));
+}
+
+TEST(SummarizeValue, INT32Dims_PRINT_V2) {
+  Tensor x = MkTensor<int>(DT_INT32, TensorShape({3, 4}),
+                           {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+  EXPECT_EQ("[[1 ... 4]\n ...\n [9 ... 12]]", x.SummarizeValue(1, true));
+  EXPECT_EQ("[[1 2 3 4]\n [5 6 7 8]\n [9 10 11 12]]",
+            x.SummarizeValue(10, true));
+  EXPECT_EQ("[[1 2 3 4]\n [5 6 7 8]\n [9 10 11 12]]",
+            x.SummarizeValue(-1, true));
+}
+
+TEST(SummarizeValue, FLOAT_PRINT_V2) {
+  Tensor x = MkTensor<float>(DT_FLOAT, TensorShape({5}), {1, 2, 3, 4, 0});
+  EXPECT_EQ("[1 2 3 4 0]", x.SummarizeValue(16, true));
+  EXPECT_EQ("[1 2 3 4 0]", x.SummarizeValue(-1, true));
+  EXPECT_EQ("[1 2 ... 4 0]", x.SummarizeValue(2, true));
+  EXPECT_EQ("[1 ... 0]", x.SummarizeValue(1, true));
+  x = MkTensor<float>(DT_FLOAT, TensorShape({2, 2}), {1, 2, 3, 4, 0});
+  EXPECT_EQ("[[1 2]\n [3 4]]", x.SummarizeValue(16, true));
+  x = MkTensor<float>(DT_FLOAT, TensorShape({2, 2, 1, 1}), {1, 2, 3, 4, 0});
+  EXPECT_EQ("[[[[1]]\n\n  [[2]]]\n\n\n [[[3]]\n\n  [[4]]]]",
+            x.SummarizeValue(16, true));
+  x = MkTensor<float>(DT_FLOAT, TensorShape({0}), {});
+  EXPECT_EQ("[]", x.SummarizeValue(16, true));
+}
+
+TEST(SummarizeValue, BOOL_PRINT_V2) {
+  Tensor x = MkTensor<bool>(DT_BOOL, TensorShape({5}), {false, true, true});
+  EXPECT_EQ("[0 1 1 0 1]", x.SummarizeValue(16, true));
+  EXPECT_EQ("[0 1 1 0 1]", x.SummarizeValue(-1, true));
+  EXPECT_EQ("[0 1 ... 0 1]", x.SummarizeValue(2, true));
+}
+
+TEST(SummarizeValue, STRING_PRINT_V2) {
+  Tensor x = MkTensor<string>(DT_STRING, TensorShape({5}),
+                              {"one", "two", "three", "four", "five"});
+  EXPECT_EQ("[one two three four five]", x.SummarizeValue(16, true));
+  EXPECT_EQ("[one two three four five]", x.SummarizeValue(-1, true));
+  x = MkTensor<string>(DT_STRING, TensorShape({5, 1, 5}),
+                       {"one", "two", "three", "four", "five"});
+  EXPECT_EQ("[one two three four five one...]", x.SummarizeValue(6, true));
+}
+
 void BM_CreateAndDestroy(int iters) {
   TensorShape shape({10, 20});
   while (--iters) {
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index 7399613..eeb5c14 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -1162,7 +1162,9 @@
     const NodeDef* node_def = node_defs_[pair->second.gdef_index];
     const OpDef* op_def;
     TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def));
-    if (key.second >= op_def->output_arg_size()) {
+    int num_outputs;
+    TF_RETURN_IF_ERROR(NumOutputsForNode(*node_def, *op_def, &num_outputs));
+    if (key.second >= num_outputs) {
       // key's index out of bounds
       missing_unused_input_map_keys_->push_back(key);
     }
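The previous bound, `op_def->output_arg_size()`, counts output arguments rather than actual outputs, so a variadic output such as `outputs: N * int32` with N > 1 made valid `input_map` keys look out of range. An illustrative fragment, not part of the patch, assuming a hypothetical NodeDef whose `N` attr is 5:

    // For an OpDef with the single output argument "outputs: N * int32" and a
    // NodeDef carrying attr N = 5:
    int num_outputs = 0;
    TF_RETURN_IF_ERROR(NumOutputsForNode(*node_def, *op_def, &num_outputs));
    // op_def->output_arg_size() == 1, but num_outputs == 5, so a key such as
    // ("variadic", 4) is in range and must not be reported as missing/unused.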
diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc
index 73142eb..3eef6bd 100644
--- a/tensorflow/core/graph/graph_constructor_test.cc
+++ b/tensorflow/core/graph/graph_constructor_test.cc
@@ -199,6 +199,10 @@
     .Output("y: T")
     .Attr("T: {float, int64}")
     .SetShapeFn(shape_inference::UnchangedShape);
+REGISTER_OP("TestVariadicOutput")
+    .Output("outputs: N * int32")
+    .Attr("N: int >= 0")
+    .SetShapeFn(shape_inference::UnknownShape);
 REGISTER_OP("TestDefaultAttr")
     .Attr("default_int: int=31415")
     .SetShapeFn(shape_inference::NoOutputs);
@@ -1463,12 +1467,15 @@
   opts.input_map[TensorId("DNE", 0)] = TensorId("input", 0);
   // Unused but not missing
   opts.input_map[TensorId("t1", 0)] = TensorId("W1", 0);
+  // Unused but not missing
+  opts.input_map[TensorId("variadic", 4)] = TensorId("input", 0);
   ExpectOK(
       R"EOF(
       node { name: 'W2' op: 'TestParams' }
       node { name: 'new_input' op: 'TestInput' input: [ '^W2' ] }
       node { name: 't1' op: 'TestMul' input: [ 'new_input:0', 'new_input:1' ] }
-      node { name: 't2' op: 'TestMul' input: [ 't1:0', 't1:0' ] }
+      node { name: 'variadic' op: 'TestVariadicOutput'
+             attr { key: "N" value { i: 5 } } }
       )EOF",
       opts, &refiner, &results);
 
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 2e644fe..f5b0105 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -28,6 +28,7 @@
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/common_runtime/optimization_registry.h"
 #include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/graph/algorithm.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/node_builder.h"
diff --git a/tensorflow/core/graph/testlib.h b/tensorflow/core/graph/testlib.h
index bd0284d..b00196f 100644
--- a/tensorflow/core/graph/testlib.h
+++ b/tensorflow/core/graph/testlib.h
@@ -32,7 +32,7 @@
 namespace graph {
 
 // Converts "g" into its corresponding GraphDef "def".
-// DEPRECATED: call g->ToGraphDef(def) instead.
+ABSL_DEPRECATED("Call g->ToGraphDef(def) instead.")
 void ToGraphDef(Graph* g, GraphDef* def);
 
 // A few helpers to construct a graph.
diff --git a/tensorflow/core/grappler/clusters/single_machine.cc b/tensorflow/core/grappler/clusters/single_machine.cc
index b97603c..e4f6bf7 100644
--- a/tensorflow/core/grappler/clusters/single_machine.cc
+++ b/tensorflow/core/grappler/clusters/single_machine.cc
@@ -93,13 +93,13 @@
             strings::StrCat("Not able to parse GPU device name: ", dev.name()));
       }
       TfGpuId tf_gpu_id(parsed.id);
-      CudaGpuId cuda_gpu_id;
-      Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id);
+      PlatformGpuId platform_gpu_id;
+      Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id);
       if (!s.ok()) {
         return errors::Unavailable("Unknown TF GPU device with id ",
                                    tf_gpu_id.value(), ": ", s.ToString());
       }
-      attr = GetLocalGPUInfo(cuda_gpu_id);
+      attr = GetLocalGPUInfo(platform_gpu_id);
     } else if (dev.device_type().find("XLA") == string::npos) {
       // Filter out the fake XLA devices to avoid double counting the actual
       // hardware resources that are available.
diff --git a/tensorflow/core/grappler/clusters/utils.cc b/tensorflow/core/grappler/clusters/utils.cc
index a751972..567e7c0 100644
--- a/tensorflow/core/grappler/clusters/utils.cc
+++ b/tensorflow/core/grappler/clusters/utils.cc
@@ -70,13 +70,14 @@
   return device;
 }
 
-DeviceProperties GetLocalGPUInfo(CudaGpuId cuda_gpu_id) {
+DeviceProperties GetLocalGPUInfo(PlatformGpuId platform_gpu_id) {
   DeviceProperties device;
   device.set_type("GPU");
 
 #if GOOGLE_CUDA
   cudaDeviceProp properties;
-  cudaError_t error = cudaGetDeviceProperties(&properties, cuda_gpu_id.value());
+  cudaError_t error =
+      cudaGetDeviceProperties(&properties, platform_gpu_id.value());
   if (error != cudaSuccess) {
     device.set_type("UNKNOWN");
     LOG(ERROR) << "Failed to get device properties, error code: " << error;
@@ -122,15 +123,15 @@
   } else if (device.type == "GPU") {
     if (device.has_id) {
       TfGpuId tf_gpu_id(device.id);
-      CudaGpuId cuda_gpu_id;
-      Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id);
+      PlatformGpuId platform_gpu_id;
+      Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id);
       if (!s.ok()) {
         LOG(ERROR) << s;
         return unknown;
       }
-      return GetLocalGPUInfo(cuda_gpu_id);
+      return GetLocalGPUInfo(platform_gpu_id);
     } else {
-      return GetLocalGPUInfo(CudaGpuId(0));
+      return GetLocalGPUInfo(PlatformGpuId(0));
     }
   }
   return unknown;
diff --git a/tensorflow/core/grappler/clusters/utils.h b/tensorflow/core/grappler/clusters/utils.h
index ca15c48..f0a342b 100644
--- a/tensorflow/core/grappler/clusters/utils.h
+++ b/tensorflow/core/grappler/clusters/utils.h
@@ -28,7 +28,7 @@
 
 // Returns the DeviceProperties for the specified GPU attached to the server on
 // which grappler is running.
-DeviceProperties GetLocalGPUInfo(CudaGpuId cuda_gpu_id);
+DeviceProperties GetLocalGPUInfo(PlatformGpuId platform_gpu_id);
 
 // Returns the DeviceProperties of the specified device
 DeviceProperties GetDeviceInfo(const DeviceNameUtils::ParsedName& device);
diff --git a/tensorflow/core/grappler/clusters/utils_test.cc b/tensorflow/core/grappler/clusters/utils_test.cc
index 74218ad..3863d62 100644
--- a/tensorflow/core/grappler/clusters/utils_test.cc
+++ b/tensorflow/core/grappler/clusters/utils_test.cc
@@ -31,22 +31,22 @@
   LOG(INFO) << "CUDA is enabled.";
   DeviceProperties properties;
 
-  // Invalid CUDA GPU ID.
-  properties = GetLocalGPUInfo(CudaGpuId(100));
+  // Invalid platform GPU ID.
+  properties = GetLocalGPUInfo(PlatformGpuId(100));
   EXPECT_EQ("UNKNOWN", properties.type());
 
-  // Succeed when a valid CUDA GPU id was inserted.
-  properties = GetLocalGPUInfo(CudaGpuId(0));
+  // Succeed when a valid platform GPU id was inserted.
+  properties = GetLocalGPUInfo(PlatformGpuId(0));
   EXPECT_EQ("GPU", properties.type());
   EXPECT_EQ("NVIDIA", properties.vendor());
 #else
   LOG(INFO) << "CUDA is not enabled.";
   DeviceProperties properties;
 
-  properties = GetLocalGPUInfo(CudaGpuId(0));
+  properties = GetLocalGPUInfo(PlatformGpuId(0));
   EXPECT_EQ("GPU", properties.type());
 
-  properties = GetLocalGPUInfo(CudaGpuId(100));
+  properties = GetLocalGPUInfo(PlatformGpuId(100));
   EXPECT_EQ("GPU", properties.type());
 #endif
 }
@@ -74,20 +74,20 @@
   EXPECT_EQ("NVIDIA", properties.vendor());
 #endif
 
-  // TF to CUDA GPU id mapping entry doesn't exist.
+  // TF to platform GPU id mapping entry doesn't exist.
   device.has_id = true;
   device.id = 0;
   properties = GetDeviceInfo(device);
   EXPECT_EQ("UNKNOWN", properties.type());
 
 #if GOOGLE_CUDA
-  // Invalid CUDA GPU id.
-  GpuIdManager::InsertTfCudaGpuIdPair(TfGpuId(0), CudaGpuId(100));
+  // Invalid platform GPU id.
+  GpuIdManager::InsertTfPlatformGpuIdPair(TfGpuId(0), PlatformGpuId(100));
   properties = GetDeviceInfo(device);
   EXPECT_EQ("UNKNOWN", properties.type());
 
-  // Valid CUDA GPU id.
-  GpuIdManager::InsertTfCudaGpuIdPair(TfGpuId(1), CudaGpuId(0));
+  // Valid platform GPU id.
+  GpuIdManager::InsertTfPlatformGpuIdPair(TfGpuId(1), PlatformGpuId(0));
   device.id = 1;
   properties = GetDeviceInfo(device);
   EXPECT_EQ("GPU", properties.type());
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index d273edd..56c8339d 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -260,13 +260,13 @@
 }
 
 bool IsEnqueue(const NodeDef& n) {
-  return (n.op().find("Enqueue") != std::string::npos &&
-          n.op().find("EnqueueMany") == std::string::npos);
+  return (n.op().find("Enqueue") != string::npos &&
+          n.op().find("EnqueueMany") == string::npos);
 }
 
 bool IsDequeue(const NodeDef& n) {
-  return (n.op().find("Dequeue") != std::string::npos &&
-          n.op().find("DequeueMany") == std::string::npos);
+  return (n.op().find("Dequeue") != string::npos &&
+          n.op().find("DequeueMany") == string::npos);
 }
 
 bool HasAnyUnknownDimensions(const TensorShapeProto& proto) {
diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc
index aad00ce..5415324b 100644
--- a/tensorflow/core/grappler/costs/utils.cc
+++ b/tensorflow/core/grappler/costs/utils.cc
@@ -127,7 +127,7 @@
 
       // For filename input, the file size can also be useful.
       if (op_def && i < op_def->input_arg_size() &&
-          op_def->input_arg(i).name().find("filename") != std::string::npos) {
+          op_def->input_arg(i).name().find("filename") != string::npos) {
         Tensor tensor;
         if (!tensor.FromProto(t)) {
           continue;
@@ -153,7 +153,7 @@
     // When the input is a handle (e.g. look up table handle), the information
     // in the op itself is not sufficient to predict the op memory.
     if (op_def && i < op_def->input_arg_size() &&
-        op_def->input_arg(i).name().find("handle") != std::string::npos) {
+        op_def->input_arg(i).name().find("handle") != string::npos) {
       string new_key = strings::StrCat("parent_", i, "_op");
       AttrValue attr;
       attr.set_s(input_node->op());
@@ -209,13 +209,13 @@
   if (DeviceNameUtils::ParseFullName(device_str, &parsed)) {
     if (parsed.type == "GPU") {
       TfGpuId tf_gpu_id(parsed.id);
-      CudaGpuId cuda_gpu_id;
-      Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id);
+      PlatformGpuId platform_gpu_id;
+      Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id);
       if (!s.ok()) {
         // We are probably running simulation without linking cuda libraries.
-        cuda_gpu_id = CudaGpuId(parsed.id);
+        platform_gpu_id = PlatformGpuId(parsed.id);
       }
-      return GetLocalGPUInfo(cuda_gpu_id);
+      return GetLocalGPUInfo(platform_gpu_id);
     } else if (parsed.type == "CPU") {
       return GetLocalCPUInfo();
     }
@@ -320,8 +320,8 @@
                  buckets_.begin(), std::plus<uint64>());
 }
 
-std::string TensorSizeHistogram::ToString() const {
-  std::string r;
+string TensorSizeHistogram::ToString() const {
+  string r;
   char buf[200];
   snprintf(buf, sizeof(buf), "Count: %lld, Average: ", num_elem_);
   r.append(buf);
diff --git a/tensorflow/core/grappler/costs/utils.h b/tensorflow/core/grappler/costs/utils.h
index d2c7c67..5fd6717 100644
--- a/tensorflow/core/grappler/costs/utils.h
+++ b/tensorflow/core/grappler/costs/utils.h
@@ -80,7 +80,7 @@
   uint64 Max() const { return max_; }
   uint64 NumElem() const { return num_elem_; }
   uint64 SumElem() const { return sum_elem_; }
-  std::string ToString() const;
+  string ToString() const;
 
  protected:
   const int Index(const uint64 value) const;
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
index 02a379f..80889af 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
@@ -1999,13 +1999,13 @@
 
   // Helper lambda to extract port num from _Send and _Recv op name.
   auto get_port_num = [](const string& name) -> int {
-    if (name.find("bn_0") != std::string::npos) {
+    if (name.find("bn_0") != string::npos) {
       return 0;
-    } else if (name.find("bn_1") != std::string::npos) {
+    } else if (name.find("bn_1") != string::npos) {
       return 1;
-    } else if (name.find("bn_2") != std::string::npos) {
+    } else if (name.find("bn_2") != string::npos) {
       return 2;
-    } else if (name.find("bn_minus1") != std::string::npos) {
+    } else if (name.find("bn_minus1") != string::npos) {
       return -1;
     }
     return -999;
diff --git a/tensorflow/core/grappler/inputs/utils.cc b/tensorflow/core/grappler/inputs/utils.cc
index 5029dff..def9198 100644
--- a/tensorflow/core/grappler/inputs/utils.cc
+++ b/tensorflow/core/grappler/inputs/utils.cc
@@ -14,10 +14,11 @@
 ==============================================================================*/
 
 #include "tensorflow/core/grappler/inputs/utils.h"
-#include "tensorflow/core/platform/env.h"
 
 #include <vector>
 
+#include "tensorflow/core/platform/env.h"
+
 namespace tensorflow {
 namespace grappler {
 
@@ -29,12 +30,12 @@
   return FilesExist(std::vector<string>(files.begin(), files.end()), nullptr);
 }
 
-bool FileExists(const std::string& file, Status* status) {
+bool FileExists(const string& file, Status* status) {
   *status = Env::Default()->FileExists(file);
   return status->ok();
 }
 
-Status ReadGraphDefFromFile(const std::string& graph_def_pbtxt_path,
+Status ReadGraphDefFromFile(const string& graph_def_pbtxt_path,
                             GraphDef* result) {
   Status status;
   if (FileExists(graph_def_pbtxt_path, &status)) {
diff --git a/tensorflow/core/grappler/inputs/utils.h b/tensorflow/core/grappler/inputs/utils.h
index 627dd53..4b9cb0a 100644
--- a/tensorflow/core/grappler/inputs/utils.h
+++ b/tensorflow/core/grappler/inputs/utils.h
@@ -29,9 +29,9 @@
                 std::vector<Status>* status = nullptr);
 bool FilesExist(const std::set<string>& files);
 
-bool FileExists(const std::string& file, Status* status);
+bool FileExists(const string& file, Status* status);
 
-Status ReadGraphDefFromFile(const std::string& graph_def_pbtxt_path,
+Status ReadGraphDefFromFile(const string& graph_def_pbtxt_path,
                             GraphDef* result);
 
 }  // end namespace grappler
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index e78239b..3521669 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -491,7 +491,7 @@
     }
   }
   // Queue ops modify the queue which is a side effect.
-  if (node.op().find("Queue") != std::string::npos) {
+  if (node.op().find("Queue") != string::npos) {
     return false;
   }
   return !ModifiesInputsInPlace(node);
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index f094c15..261dee4 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -8,10 +8,6 @@
 
 # Platform specific build config
 load(
-    "//tensorflow/core:platform/default/build_config.bzl",
-    "tf_protos_grappler",
-)
-load(
     "//tensorflow/core:platform/default/build_config_root.bzl",
     "if_static",
 )
@@ -97,7 +93,6 @@
     deps = [
         ":evaluation_utils",
         ":graph_optimizer",
-        ":symbolic_shapes",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
@@ -107,6 +102,7 @@
         "//tensorflow/core/grappler:utils",
         "//tensorflow/core/grappler/clusters:cluster",
         "//tensorflow/core/grappler/costs:graph_properties",
+        "//tensorflow/core/grappler/utils:symbolic_shapes",
     ],
 )
 
@@ -261,7 +257,6 @@
         ":constant_folding",
         ":graph_optimizer",
         ":graph_optimizer_stage",
-        ":symbolic_shapes",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
@@ -270,6 +265,7 @@
         "//tensorflow/core/grappler:op_types",
         "//tensorflow/core/grappler:utils",
         "//tensorflow/core/grappler/costs:graph_properties",
+        "//tensorflow/core/grappler/utils:symbolic_shapes",
         "//tensorflow/core/grappler/utils:topological_sort",
     ],
 )
@@ -515,6 +511,7 @@
         ":custom_graph_optimizer_registry",
         ":debug_stripper",
         ":dependency_optimizer",
+        ":experimental_implementation_selector",
         ":function_optimizer",
         ":graph_optimizer",
         ":layout_optimizer",
@@ -647,7 +644,6 @@
     visibility = ["//visibility:public"],
     deps = [
         ":graph_optimizer",
-        ":symbolic_shapes",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
@@ -657,6 +653,7 @@
         "//tensorflow/core/grappler:utils",
         "//tensorflow/core/grappler/costs:graph_properties",
         "//tensorflow/core/grappler/utils:frame",
+        "//tensorflow/core/grappler/utils:symbolic_shapes",
     ],
 )
 
@@ -714,31 +711,6 @@
 )
 
 cc_library(
-    name = "symbolic_shapes",
-    srcs = ["symbolic_shapes.cc"],
-    hdrs = ["symbolic_shapes.h"],
-    visibility = ["//visibility:public"],
-    deps = [
-        "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:protos_all_cc",
-    ] + tf_protos_grappler(),
-)
-
-tf_cc_test(
-    name = "symbolic_shapes_test",
-    srcs = ["symbolic_shapes_test.cc"],
-    deps = [
-        ":symbolic_shapes",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core:test",
-        "//tensorflow/core:test_main",
-    ],
-)
-
-cc_library(
     name = "debug_stripper",
     srcs = ["debug_stripper.cc"],
     hdrs = [
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 11ce121..76a9dca 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -35,8 +35,8 @@
 #include "tensorflow/core/grappler/op_types.h"
 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
 #include "tensorflow/core/grappler/optimizers/graph_optimizer_stage.h"
-#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
 #include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
 #include "tensorflow/core/grappler/utils/topological_sort.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
@@ -1325,38 +1325,26 @@
   }
 
   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
-    const string node_name = node->name();
     NodeDef* x;
     NodeDef* y;
     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
     TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
     bool updated = false;
-    if (IsAdd(*node)) {
-      if (IsNeg(*x)) {
-        // (-a) + b = b - a
-        node->set_op("Sub");
-        node->mutable_input()->SwapElements(0, 1);
-        node->set_input(1, x->input(0));
-        node->add_input(AsControlDependency(x->name()));
-        ctx().node_map->AddOutput(NodeName(x->input(0)), node_name);
-        updated = true;
-      } else if (IsNeg(*y)) {
-        // a + (-b) = a - b
-        node->set_op("Sub");
-        node->set_input(1, y->input(0));
-        node->add_input(AsControlDependency(y->name()));
-        ctx().node_map->AddOutput(NodeName(y->input(0)), node_name);
-        updated = true;
-      }
-    } else if (IsSub(*node)) {
-      if (IsNeg(*y)) {
-        // a - (-b) = a + b
-        node->set_op("Add");
-        node->set_input(1, y->input(0));
-        node->add_input(AsControlDependency(y->name()));
-        ctx().node_map->AddOutput(NodeName(y->input(0)), node_name);
-        updated = true;
-      }
+    if (IsNeg(*y)) {
+      // a - (-b) = a + b or a + (-b) = a - b
+      ForwardControlDependencies(node, {y});
+      ctx().node_map->UpdateInput(node->name(), node->input(1), y->input(0));
+      node->set_op(IsAdd(*node) ? "Sub" : "Add");
+      node->set_input(1, y->input(0));
+      updated = true;
+    } else if (IsAdd(*node) && IsNeg(*x)) {
+      // (-a) + b = b - a
+      ForwardControlDependencies(node, {x});
+      ctx().node_map->UpdateInput(node->name(), node->input(0), x->input(0));
+      node->set_op("Sub");
+      node->mutable_input()->SwapElements(0, 1);
+      node->set_input(1, x->input(0));
+      updated = true;
     }
     if (updated) {
       AddToOptimizationQueue(node);
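// Minimal, self-contained sketch (toy types, not the grappler NodeDef/IsNeg
// API) of the negation-removal rewrites applied above:
//   a + (-b) -> a - b,   a - (-b) -> a + b,   (-a) + b -> b - a.
// Control-dependency forwarding and node-map bookkeeping are omitted.
#include <iostream>
#include <string>
#include <utility>
#include <vector>

struct ToyNode {                       // hypothetical stand-in for NodeDef
  std::string op;                      // "Add" or "Sub"
  std::vector<std::string> inputs;     // names of the two inputs
};

// x_is_neg / y_is_neg stand in for IsNeg(); x_operand / y_operand are the
// inputs of those Neg nodes (what the rewrite bypasses them with).
bool RemoveNegation(ToyNode* node, bool x_is_neg, bool y_is_neg,
                    const std::string& x_operand,
                    const std::string& y_operand) {
  if (y_is_neg) {
    // a + (-b) = a - b   or   a - (-b) = a + b
    node->op = (node->op == "Add") ? "Sub" : "Add";
    node->inputs[1] = y_operand;
    return true;
  }
  if (node->op == "Add" && x_is_neg) {
    // (-a) + b = b - a
    node->op = "Sub";
    std::swap(node->inputs[0], node->inputs[1]);
    node->inputs[1] = x_operand;
    return true;
  }
  return false;
}

int main() {
  ToyNode n{"Add", {"neg_a", "b"}};    // models Add(Neg(a), b)
  RemoveNegation(&n, /*x_is_neg=*/true, /*y_is_neg=*/false, "a", "b");
  std::cout << n.op << "(" << n.inputs[0] << ", " << n.inputs[1] << ")\n";
  // prints: Sub(b, a)
}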
@@ -2379,26 +2367,24 @@
   }
 
   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
-    const auto& p = ctx().graph_properties->GetInputProperties(node->name())[1];
-    for (int i = 0; i < p.shape().dim_size(); ++i) {
-      if (p.shape().dim(i).size() < 0) {
+    const auto& pow_props =
+        ctx().graph_properties->GetInputProperties(node->name())[1];
+    for (int i = 0; i < pow_props.shape().dim_size(); ++i) {
+      if (pow_props.shape().dim(i).size() < 0) {
         // skip if p is not fully defined.
         return Status::OK();
       }
     }
-    if (TensorShape::IsValid(p.shape()) && p.has_value()) {
-      Tensor pow(p.dtype(), p.shape());
-      if (!pow.FromProto(p.value())) {
+    if (TensorShape::IsValid(pow_props.shape()) && pow_props.has_value()) {
+      Tensor pow(pow_props.dtype(), pow_props.shape());
+      if (!pow.FromProto(pow_props.value())) {
         return errors::InvalidArgument("Cannot parse tensor from proto: ",
-                                       p.value().DebugString());
+                                       pow_props.value().DebugString());
       }
 
       complex128 prev, curr;
       for (int i = 0; i < pow.NumElements(); ++i) {
-        if (!GetElementUnexhaustive(pow, i,
-                                    {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE,
-                                     DT_COMPLEX64, DT_COMPLEX128},
-                                    &curr)) {
+        if (!GetElementUnexhaustive(pow, i, {pow_props.dtype()}, &curr)) {
           // input data type is not supported by Pow. Skip.
           return Status::OK();
         }
@@ -2411,12 +2397,19 @@
       NodeDef *x, *y;
       TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
       TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
+      const auto& value_props =
+          ctx().graph_properties->GetInputProperties(node->name())[0];
+      const TensorShapeProto& output_shape =
+          ctx().graph_properties->GetOutputProperties(node->name())[0].shape();
       if (curr == complex128(2, 0)) {
         node->set_op("Square");
         node->set_input(1, AsControlDependency(y->name()));
         AddToOptimizationQueue(node);
         AddToOptimizationQueue(y);
-      } else if (curr == complex128(1, 0)) {
+      } else if (curr == complex128(1, 0) &&
+                 ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
+        // Pow could be used to broadcast, so make sure the shapes of the two
+        // arguments are identical before replacing Pow with Identity.
         node->set_op("Identity");
         node->set_input(1, AsControlDependency(y->name()));
         AddToOptimizationQueue(node);
@@ -2426,20 +2419,20 @@
         node->set_input(1, AsControlDependency(y->name()));
         AddToOptimizationQueue(node);
         AddToOptimizationQueue(y);
-      } else if (curr == complex128(0, 0)) {
-        const auto& b =
-            ctx().graph_properties->GetInputProperties(node->name())[0];
-        for (int i = 0; i < b.shape().dim_size(); ++i) {
-          if (b.shape().dim(i).size() < 0) {
+      } else if (curr == complex128(0, 0) &&
+                 ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
+        for (int i = 0; i < value_props.shape().dim_size(); ++i) {
+          if (value_props.shape().dim(i).size() < 0) {
             // skip if b is not fully defined.
             return Status::OK();
           }
         }
-        if (TensorShape::IsValid(b.shape()) && b.has_value()) {
-          Tensor base(b.dtype(), b.shape());
-          if (!base.FromProto(b.value())) {
+        if (TensorShape::IsValid(value_props.shape()) &&
+            value_props.has_value()) {
+          Tensor base(value_props.dtype(), value_props.shape());
+          if (!base.FromProto(value_props.value())) {
             return errors::InvalidArgument("Cannot parse tensor from proto: ",
-                                           b.value().DebugString());
+                                           value_props.value().DebugString());
           }
           node->set_op("Const");
           Tensor c(base.dtype(), base.shape());
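// Simplified sketch of the shape check guarding the exponent==1 and
// exponent==0 rewrites above. Pow broadcasts, so e.g. Pow(z /*scalar*/,
// ones /*shape [1,3]*/) yields a [1,3] tensor while Identity(z) would stay
// scalar. Toy shapes only: unlike ShapesSymbolicallyEqual, this accepts just
// fully known, identical shapes.
#include <cassert>
#include <cstddef>
#include <vector>

using ToyShape = std::vector<int>;  // hypothetical stand-in for TensorShapeProto

bool SafeToDropPow(const ToyShape& base_shape, const ToyShape& output_shape) {
  if (base_shape.size() != output_shape.size()) return false;
  for (std::size_t i = 0; i < base_shape.size(); ++i) {
    if (base_shape[i] < 0 || base_shape[i] != output_shape[i]) return false;
  }
  return true;
}

int main() {
  assert(SafeToDropPow({1, 2}, {1, 2}));   // x ** 1 -> Identity(x) is safe
  assert(!SafeToDropPow({}, {1, 3}));      // scalar ** ones([1,3]) broadcasts
}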
@@ -2597,12 +2590,10 @@
   ~ConvertExpm1Stage() override = default;
 
   bool IsSupported(const NodeDef* node) const override {
-    if (!IsSub(*node))
-      return false;
+    if (!IsSub(*node)) return false;
 
     NodeDef* input;
-    if (!GetInputNode(node->input(0), &input).ok())
-      return false;
+    if (!GetInputNode(node->input(0), &input).ok()) return false;
 
     return IsExp(*input);
   }
@@ -2622,10 +2613,8 @@
       return Status::OK();
     }
 
-    const auto& t =
-        ctx().graph_properties->GetInputProperties(exp->name())[0];
-    const auto& c =
-        ctx().graph_properties->GetInputProperties(node->name())[1];
+    const auto& t = ctx().graph_properties->GetInputProperties(exp->name())[0];
+    const auto& c = ctx().graph_properties->GetInputProperties(node->name())[1];
     for (int k = 0; k < c.shape().dim_size(); ++k) {
       // Skip if c shape is not fully determined.
       if (c.shape().dim(k).size() < 0) {
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 39517ed..77f3c64 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -581,7 +581,7 @@
   const NodeDef* new_const = node_map.GetNode(optimized_const_name);
   ASSERT_NE(new_const, nullptr);
   EXPECT_EQ("^x", new_const->input(0));
-  EXPECT_EQ(std::string("\0\0\0@", 4),
+  EXPECT_EQ(string("\0\0\0@", 4),
             new_const->attr().at("value").tensor().tensor_content());
 
   const NodeDef* new_mul = node_map.GetNode(optimized_mul_name);
@@ -625,7 +625,7 @@
   const NodeDef* new_const = node_map.GetNode(optimized_const_name);
   ASSERT_NE(new_const, nullptr);
   EXPECT_EQ("^x", new_const->input(0));
-  EXPECT_EQ(std::string("\0\0\0@", 4),
+  EXPECT_EQ(string("\0\0\0@", 4),
             new_const->attr().at("value").tensor().tensor_content());
 
   const NodeDef* new_mul = node_map.GetNode(optimized_mul_name);
@@ -2353,9 +2353,14 @@
   Output sub_negx_y = ops::Sub(s.WithOpName("Sub_negx_y"), neg_x, y);
   Output sub_x_negy = ops::Sub(s.WithOpName("Sub_x_negy"), x, neg_y);
   Output sub_negx_negy = ops::Sub(s.WithOpName("Sub_negx_negy"), neg_x, neg_y);
-  auto add_all = ops::AddN(s.WithOpName("add_all"),
-                           {add_x_y, add_negx_y, add_x_negy, add_negx_negy,
-                            sub_x_y, sub_negx_y, sub_x_negy, sub_negx_negy});
+  Output neg_x_with_dep = ops::Neg(
+      s.WithOpName("Neg_x_with_dep").WithControlDependencies({add_x_y}), x);
+  Output add_negx_with_dep_y =
+      ops::Add(s.WithOpName("Add_negx_with_dep_y"), neg_x_with_dep, y);
+  auto add_all =
+      ops::AddN(s.WithOpName("add_all"),
+                {add_x_y, add_negx_y, add_x_negy, add_negx_negy, sub_x_y,
+                 sub_negx_y, sub_x_negy, sub_negx_negy, add_negx_with_dep_y});
 
   GrapplerItem item;
   item.fetch = {"add_all"};
@@ -2370,7 +2375,7 @@
   GraphDef output;
   ArithmeticOptimizer optimizer;
   EnableOnlyRemoveNegation(&optimizer);
-  OptimizeAndPrune(&optimizer, &item, &output);
+  OptimizeTwice(&optimizer, &item, &output);
 
   EXPECT_EQ(item.graph.node_size(), output.node_size());
   int found = 0;
@@ -2379,42 +2384,43 @@
     if (node.name() == "Add_negx_y") {
       ++found;
       EXPECT_EQ("Sub", node.op());
-      EXPECT_EQ(3, node.input_size());
+      EXPECT_EQ(2, node.input_size());
       EXPECT_EQ("y", node.input(0));
       EXPECT_EQ("x", node.input(1));
-      EXPECT_EQ("^Neg_x", node.input(2));
     } else if (node.name() == "Add_x_negy") {
       ++found;
       EXPECT_EQ("Sub", node.op());
-      EXPECT_EQ(3, node.input_size());
+      EXPECT_EQ(2, node.input_size());
       EXPECT_EQ("x", node.input(0));
       EXPECT_EQ("y", node.input(1));
-      EXPECT_EQ("^Neg_y", node.input(2));
     } else if (node.name() == "Add_negx_negy") {
       ++found;
       EXPECT_EQ("Sub", node.op());
-      EXPECT_EQ(3, node.input_size());
-      EXPECT_EQ("Neg_y", node.input(0));
-      EXPECT_EQ("x", node.input(1));
-      EXPECT_EQ("^Neg_x", node.input(2));
+      EXPECT_EQ(2, node.input_size());
+      EXPECT_EQ("Neg_x", node.input(0));
+      EXPECT_EQ("y", node.input(1));
     } else if (node.name() == "Sub_x_negy") {
       ++found;
       EXPECT_EQ("Add", node.op());
-      EXPECT_EQ(3, node.input_size());
+      EXPECT_EQ(2, node.input_size());
       EXPECT_EQ("x", node.input(0));
       EXPECT_EQ("y", node.input(1));
-      EXPECT_EQ("^Neg_y", node.input(2));
     } else if (node.name() == "Sub_negx_negy") {
       ++found;
       EXPECT_EQ("Sub", node.op());
-      EXPECT_EQ(4, node.input_size());
+      EXPECT_EQ(2, node.input_size());
       EXPECT_EQ("y", node.input(0));
       EXPECT_EQ("x", node.input(1));
-      EXPECT_EQ("^Neg_y", node.input(2));
-      EXPECT_EQ("^Neg_x", node.input(3));
+    } else if (node.name() == "Add_negx_with_dep_y") {
+      ++found;
+      EXPECT_EQ("Sub", node.op());
+      EXPECT_EQ(3, node.input_size());
+      EXPECT_EQ("y", node.input(0));
+      EXPECT_EQ("x", node.input(1));
+      EXPECT_EQ("^Add_x_y", node.input(2));
     }
   }
-  EXPECT_EQ(5, found);
+  EXPECT_EQ(6, found);
 
   auto tensors = EvaluateNodes(output, item.fetch, feed);
   EXPECT_EQ(1, tensors.size());
@@ -2468,6 +2474,9 @@
   auto y_Point5 = ops::Const(s.WithOpName("y_.5"), {-0.5f, -0.5f}, {1, 2});
   auto y_1 = ops::Const(s.WithOpName("y_1"), {-1.0f, -1.0f}, {1, 2});
   auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
+  auto z = ops::Const(s.WithOpName("z"), {42.0f}, {});
+  auto ones = ops::Const(s.WithOpName("ones"), {1.0f, 1.0f, 1.0f}, {1, 3});
+  auto zeros = ops::Const(s.WithOpName("zeros"), {0.0f, 0.0f, 0.0f}, {1, 3});
   Output out2 = ops::Pow(s.WithOpName("out2"), x, y2);
   Output out1 = ops::Pow(s.WithOpName("out1"), x, y1);
   Output outPoint5 = ops::Pow(s.WithOpName("out.5"), x, yPoint5);
@@ -2475,21 +2484,24 @@
   Output out_Point5 = ops::Pow(s.WithOpName("out_.5"), x, y_Point5);
   Output out_1 = ops::Pow(s.WithOpName("out_1"), x, y_1);
   Output out = ops::Pow(s.WithOpName("out"), x, y);
+  Output out_bcast1 = ops::Pow(s.WithOpName("out_bcast1"), z, ones);
+  Output out_bcast2 = ops::Pow(s.WithOpName("out_bcast2"), z, zeros);
 
   GrapplerItem item;
-  item.fetch = {"out2", "out1", "out.5", "out0", "out_.5", "out_1", "out"};
+  item.fetch = {"out2",  "out1", "out.5",      "out0",      "out_.5",
+                "out_1", "out",  "out_bcast1", "out_bcast2"};
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
   auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
-  EXPECT_EQ(7, tensors_expected.size());
+  EXPECT_EQ(9, tensors_expected.size());
 
   GraphDef got;
   ArithmeticOptimizer optimizer;
   EnableOnlyConvertPow(&optimizer);
   OptimizeAndPrune(&optimizer, &item, &got);
   auto tensors = EvaluateNodes(got, item.fetch);
-  EXPECT_EQ(7, tensors.size());
+  EXPECT_EQ(9, tensors.size());
 
-  for (int i = 0; i < 7; ++i) {
+  for (int i = 0; i < tensors.size(); ++i) {
     EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
     test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
   }
@@ -2503,6 +2515,9 @@
   AddNode("y_.5", "Const", {}, {}, &want);
   AddNode("y_1", "Const", {}, {}, &want);
   AddNode("y", "Const", {}, {}, &want);
+  AddNode("z", "Const", {}, {}, &want);
+  AddNode("ones", "Const", {}, {}, &want);
+  AddNode("zeros", "Const", {}, {}, &want);
   AddNode("out2", "Square", {"x", AsControlDependency("y2")}, {}, &want);
   AddNode("out1", "Identity", {"x", AsControlDependency("y1")}, {}, &want);
   AddNode("out.5", "Sqrt", {"x", AsControlDependency("y.5")}, {}, &want);
@@ -2511,6 +2526,8 @@
   AddNode("out_.5", "Rsqrt", {"x", AsControlDependency("y_.5")}, {}, &want);
   AddNode("out_1", "Reciprocal", {"x", AsControlDependency("y_1")}, {}, &want);
   AddNode("out", "Pow", {"x", "y"}, {}, &want);
+  AddNode("out_bcast1", "Pow", {"z", "ones"}, {}, &want);
+  AddNode("out_bcast2", "Pow", {"z", "zeros"}, {}, &want);
 
   CompareGraphs(want, got);
 }
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 99737a7..cfbd298 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -32,8 +32,8 @@
 #include "tensorflow/core/grappler/grappler_item.h"
 #include "tensorflow/core/grappler/op_types.h"
 #include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
-#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
 #include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -437,25 +437,6 @@
 }
 
 namespace {
-bool ShapesEqual(const TensorShapeProto& shape1,
-                 const TensorShapeProto& shape2) {
-  if (shape1.unknown_rank() || shape2.unknown_rank()) {
-    return false;
-  }
-  if (shape1.dim_size() != shape2.dim_size()) {
-    return false;
-  }
-  for (int i = 0; i < shape1.dim_size(); ++i) {
-    if (shape1.dim(i).size() != shape2.dim(i).size()) {
-      return false;
-    }
-    if (shape1.dim(i).size() == -1 || shape2.dim(i).size() == -1) {
-      return false;
-    }
-  }
-  return true;
-}
-
 bool ExtractShape(const NodeDef& shape_node, const GraphProperties& properties,
                   BCast::Vec* shape, int64* min_id) {
   if (shape_node.op() == "Shape") {
@@ -2348,7 +2329,8 @@
         properties.GetInputProperties(node->name())[1].shape();
     const bool x_is_zero = IsZeros(*x);
     const bool x_is_one = x_is_zero ? false : IsOnes(*x);
-    const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape);
+    const bool y_matches_output_shape =
+        ShapesSymbolicallyEqual(output_shape, y_shape);
     if (y_matches_output_shape &&
         ((is_mul && x_is_one) || (is_add && x_is_zero))) {
       // 1 * y = y or 0 + y = y.
@@ -2378,7 +2360,8 @@
         properties.GetInputProperties(node->name())[0].shape();
     const bool y_is_zero = IsZeros(*y);
     const bool y_is_one = y_is_zero ? false : IsOnes(*y);
-    const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape);
+    const bool x_matches_output_shape =
+        ShapesSymbolicallyEqual(output_shape, x_shape);
     if (x_matches_output_shape && (((is_mul || is_any_div) && y_is_one) ||
                                    ((is_add || is_sub) && y_is_zero))) {
       // x * 1 = x or x / 1 = x or x +/- 0 = x
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index 530c957..79d5fe8 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -19,7 +19,6 @@
         "//tensorflow/core/grappler:op_types",
         "//tensorflow/core/grappler:utils",
         "//tensorflow/core/grappler/clusters:cluster",
-        "//tensorflow/core/kernels:cast_op",
         "//tensorflow/core/grappler/utils:topological_sort",
         "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
         "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
@@ -50,14 +49,15 @@
     visibility = ["//visibility:public"],
     deps = [
         ":graph_utils",
+        ":function_utils",
         "//tensorflow/core/grappler:mutable_graph_view",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core/grappler:grappler_item",
         "//tensorflow/core/grappler:op_types",
         "//tensorflow/core/grappler:utils",
-        "//tensorflow/core/kernels:cast_op",
         "//tensorflow/core/kernels:functional_ops",
+        "//tensorflow/core/kernels:control_flow_ops",
         "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
         "//tensorflow/core:lib_internal",
     ] + tf_protos_all(),
@@ -68,6 +68,7 @@
     srcs = ["fusion_utils_test.cc"],
     visibility = ["//visibility:public"],
     deps = [
+        ":function_utils",
         ":fusion_utils",
         ":graph_utils",
         "//tensorflow/core:framework",
@@ -79,6 +80,40 @@
 )
 
 cc_library(
+    name = "function_utils",
+    srcs = ["function_utils.cc"],
+    hdrs = [
+        "function_utils.h",
+    ],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core/grappler:mutable_graph_view",
+        "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler:utils",
+    ] + tf_protos_all(),
+)
+
+tf_cc_test(
+    name = "function_utils_test",
+    srcs = ["function_utils_test.cc"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":function_utils",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+        "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+        "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+        "//tensorflow/core/kernels:cast_op",
+        "//tensorflow/tools/graph_transforms:transform_utils",
+    ],
+)
+
+cc_library(
     name = "graph_utils",
     srcs = ["graph_utils.cc"],
     hdrs = [
@@ -107,7 +142,6 @@
         "//tensorflow/core/grappler:grappler_item",
         "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
         "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
-        "//tensorflow/core/kernels:cast_op",
     ],
 )
 
@@ -139,7 +173,9 @@
     ],
     visibility = ["//visibility:public"],
     deps = [
+        ":function_utils",
         ":graph_utils",
+        ":vectorization_utils",
         "//tensorflow/core:lib",
         "//tensorflow/core/grappler:mutable_graph_view",
         "//tensorflow/core/grappler:grappler_item",
@@ -164,7 +200,6 @@
         "//tensorflow/core:test_main",
         "//tensorflow/core:testlib",
         "//tensorflow/core/grappler:grappler_item",
-        "//tensorflow/core/kernels:cast_op",  # Must be linked for the testlib functions to work.
     ],
 )
 
@@ -256,7 +291,6 @@
         "//tensorflow/core/grappler:op_types",
         "//tensorflow/core/grappler:utils",
         "//tensorflow/core/grappler/clusters:cluster",
-        "//tensorflow/core/kernels:cast_op",
         "//tensorflow/core/grappler/utils:topological_sort",
         "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
         "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
@@ -275,6 +309,43 @@
         "//tensorflow/core:test_main",
         "//tensorflow/core:testlib",
         "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/kernels:control_flow_ops",
+    ],
+)
+
+cc_library(
+    name = "map_parallelization",
+    srcs = ["map_parallelization.cc"],
+    hdrs = [
+        "map_parallelization.h",
+    ],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":graph_utils",
+        "//tensorflow/core/grappler:mutable_graph_view",
+        "//tensorflow/core:lib",
+        "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler:op_types",
+        "//tensorflow/core/grappler:utils",
+        "//tensorflow/core/grappler/clusters:cluster",
+        "//tensorflow/core/grappler/utils:topological_sort",
+        "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+        "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+    ] + tf_protos_all(),
+)
+
+tf_cc_test(
+    name = "map_parallelization_test",
+    srcs = ["map_parallelization_test.cc"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":graph_utils",
+        ":map_parallelization",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+        "//tensorflow/core/grappler:grappler_item",
     ],
 )
 
@@ -355,6 +426,7 @@
         ":map_and_batch_fusion",
         ":map_and_filter_fusion",
         ":map_fusion",
+        ":map_parallelization",
         ":map_vectorization",
         ":noop_elimination",
         ":shuffle_and_repeat_fusion",
@@ -375,3 +447,42 @@
         "//tensorflow/core/grappler:grappler_item",
     ],
 )
+
+cc_library(
+    name = "vectorization_utils",
+    srcs = ["vectorization_utils.cc"],
+    hdrs = [
+        "vectorization_utils.h",
+    ],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":function_utils",
+        ":graph_utils",
+        "@com_google_absl//absl/strings",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core/grappler:mutable_graph_view",
+        "//tensorflow/core/grappler:utils",
+        "//tensorflow/core/grappler/utils:functions",
+    ] + tf_protos_all(),
+)
+
+tf_cc_test(
+    name = "vectorization_utils_test",
+    srcs = ["vectorization_utils_test.cc"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":function_utils",
+        ":vectorization_utils",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+        "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+        "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+        "//tensorflow/core/kernels:cast_op",
+        "//tensorflow/tools/graph_transforms:transform_utils",
+    ] + tf_protos_all(),
+)
diff --git a/tensorflow/core/grappler/optimizers/data/function_utils.cc b/tensorflow/core/grappler/optimizers/data/function_utils.cc
new file mode 100644
index 0000000..e95ea1a
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/function_utils.cc
@@ -0,0 +1,196 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/util/ptr_util.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace function_utils {
+namespace {
+
+template <typename Predicate, typename Collection>
+std::vector<int> GetElementIndicesWithPredicate(const Predicate& predicate,
+                                                const Collection& collection) {
+  std::vector<int> indices = {};
+  unsigned idx = 0;
+  for (auto&& element : collection) {
+    if (predicate(element)) {
+      indices.push_back(idx);
+    }
+    idx++;
+  }
+  return indices;
+}
+
+}  // namespace
+
+FunctionDefTensorDesc::FunctionDefTensorDesc(const string& node_name,
+                                             const string& output, int position)
+    : node_name(node_name), node_output(output), position(position) {
+  full_str = strings::StrCat(node_name, ":", node_output, ":", position);
+}
+
+FunctionDefTensorDesc::FunctionDefTensorDesc(const string& input) {
+  // Parses node_name:node_output:position string into its components.
+  full_str = input;
+  StringPiece capture;
+  StringPiece remaining;
+
+  // Parse "node_name"
+  if (strings::Scanner(input)
+          .One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE)
+          .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
+          .GetResult(&remaining, &capture)) {
+    node_name = string(capture.data(), capture.size());
+  }
+
+  // Parse "node_output" if it exists
+  if (strings::Scanner(remaining)
+          .OneLiteral(":")
+          .RestartCapture()
+          .One(strings::Scanner::LETTER)
+          .Any(strings::Scanner::LETTER_DIGIT_UNDERSCORE)
+          .GetResult(&remaining, &capture)) {
+    node_output = string(capture.data(), capture.size());
+  }
+
+  // Parse "position" if it exists
+  if (strings::Scanner(remaining)
+          .OneLiteral(":")
+          .RestartCapture()
+          .Many(strings::Scanner::DIGIT)
+          .GetResult(nullptr, &capture)) {
+    CHECK(strings::safe_strto32(capture, &position));
+  }
+}
+
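// Hypothetical usage sketch of the parser above, showing the components
// extracted from the two FunctionDef input-string forms (node-derived vs.
// argument-derived); the values mirror the Parsing test in
// function_utils_test.cc below.
#include <cassert>
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"

void FunctionDefTensorDescExamples() {
  using tensorflow::grappler::function_utils::FunctionDefTensorDesc;

  FunctionDefTensorDesc from_node("Cast:y:0");   // derives from a node
  assert(from_node.node_name == "Cast");
  assert(from_node.node_output == "y");
  assert(from_node.position == 0);

  FunctionDefTensorDesc from_arg("Arg0");        // derives from an argument
  assert(from_arg.node_name == "Arg0");
  assert(from_arg.node_output.empty());
  assert(from_arg.position == -1);               // position keeps its default
}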
+// TODO(rachelim): Create a utility class similar to MutableGraphView for
+// FunctionDefs, and use that to manipulate functions. It would be more
+// performant to keep mappings of nodes->inputs/outputs, so that we don't
+// have to search over all nodes each time.
+// Note that we're not using GrapplerFunctionItem because it doesn't cover
+// some of our desired uses (e.g. changing the outputs of a function), and the
+// FunctionDef -> GraphDef conversion isn't really necessary in this case.
+void ReplaceReferences(const string& from, const string& to,
+                       FunctionDef* func) {
+  for (NodeDef& n : *func->mutable_node_def()) {
+    std::replace(n.mutable_input()->begin(), n.mutable_input()->end(), from,
+                 to);
+  }
+
+  for (auto& p : *func->mutable_ret()) {
+    if (p.second == from) {
+      p.second = to;
+    }
+  }
+}
+
+void AddFunctionOutputWithUniqueName(StringPiece prefix,
+                                     StringPiece output_tensor_name,
+                                     FunctionDef* function, DataType dt) {
+  string name = string(prefix);
+  int id = function->signature().output_arg_size();
+  while (ContainsFunctionOutputWithName(name, *function)) {
+    name = strings::StrCat(prefix, "/_", id);
+    ++id;
+  }
+  auto* output = function->mutable_signature()->mutable_output_arg()->Add();
+  output->set_name(name);
+  output->set_type(dt);
+
+  (*function->mutable_ret())[name] = string(output_tensor_name);
+}
+
+NodeDef* AddNode(StringPiece name, StringPiece op,
+                 const std::vector<string>& inputs,
+                 const std::vector<std::pair<string, AttrValue>>& attributes,
+                 FunctionDef* fd) {
+  NodeDef* node = fd->add_node_def();
+  if (!name.empty()) {
+    node->set_name(string(name));
+  } else {
+    SetUniqueFunctionNodeName(op, fd, node);
+  }
+  node->set_op(string(op));
+  for (const string& input : inputs) {
+    node->add_input(input);
+  }
+  for (auto attr : attributes) {
+    (*node->mutable_attr())[attr.first] = attr.second;
+  }
+  return node;
+}
+
+bool ContainsFunctionNodeWithName(StringPiece name,
+                                  const FunctionDef& function) {
+  return FindFunctionNodeWithName(name, function) != -1;
+}
+
+bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
+  return FindFunctionNodeWithOp(op, function) != -1;
+}
+
+bool ContainsFunctionOutputWithName(StringPiece name,
+                                    const FunctionDef& function) {
+  return FindFunctionOutputWithName(name, function) != -1;
+}
+
+int FindFunctionInputWithName(StringPiece name, const FunctionDef& function) {
+  std::vector<int> indices = GetElementIndicesWithPredicate(
+      [&name](const OpDef_ArgDef& arg) { return arg.name() == name; },
+      function.signature().input_arg());
+  return indices.empty() ? -1 : indices.front();
+}
+
+int FindFunctionOutputWithName(StringPiece name, const FunctionDef& function) {
+  std::vector<int> indices = GetElementIndicesWithPredicate(
+      [&name](const OpDef_ArgDef& arg) { return arg.name() == name; },
+      function.signature().output_arg());
+  return indices.empty() ? -1 : indices.front();
+}
+
+int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function) {
+  std::vector<int> indices = GetElementIndicesWithPredicate(
+      [&name](const NodeDef& node) { return node.name() == name; },
+      function.node_def());
+  return indices.empty() ? -1 : indices.front();
+}
+
+int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
+  std::vector<int> indices = GetElementIndicesWithPredicate(
+      [&op](const NodeDef& node) { return node.op() == op; },
+      function.node_def());
+
+  return indices.empty() ? -1 : indices.front();
+}
+
+void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
+                               NodeDef* node) {
+  string name = string(prefix);
+  int id = function->node_def_size();
+  while (ContainsFunctionNodeWithName(name, *function)) {
+    name = strings::StrCat(prefix, "/_", id);
+    ++id;
+  }
+  node->set_name(std::move(name));
+}
+
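// Hypothetical usage sketch of the unique-name scheme above: when `prefix`
// already names a node in the function, a "/_<id>" suffix is appended, with
// id seeded from the current node count. XTimesTwo is the testlib function
// used by the tests below; it already contains a node named "two".
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"

void UniqueFunctionNodeNameExample() {
  tensorflow::FunctionDef func = tensorflow::test::function::XTimesTwo();
  tensorflow::NodeDef node;
  tensorflow::grappler::function_utils::SetUniqueFunctionNodeName("two", &func,
                                                                  &node);
  // "two" is taken, so the node receives a suffixed name such as "two/_<N>",
  // which no existing node in `func` uses.
}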
+}  // end namespace function_utils
+}  // end namespace grappler
+}  // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/function_utils.h b/tensorflow/core/grappler/optimizers/data/function_utils.h
new file mode 100644
index 0000000..d4ce824
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/function_utils.h
@@ -0,0 +1,108 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace function_utils {
+// This namespace contains utility functions for querying and modifying
+// FunctionDefs.
+
+// Describes a FunctionDef input tensor. In FunctionDefs, input tensor strings
+// have the format node_name:node_output:position (if they derive from nodes),
+// or input_name (if they derive from an argument).
+struct FunctionDefTensorDesc {
+  FunctionDefTensorDesc() = default;
+
+  FunctionDefTensorDesc(const string& node_name, const string& output,
+                        int position);
+
+  // Parses node_name:node_output:position string into its components.
+  explicit FunctionDefTensorDesc(const string& input);
+
+  // TODO(rachelim): Add provisions to deal with special formats, like how
+  // GrapplerFunctionItem expands node output range if position is not defined
+  string full_str;
+  string node_name;
+  string node_output;
+  int position = -1;
+};
+
+// Replaces all references to `from` tensor in func's nodes' inputs and retvals
+// to `to` tensor. This is similar to `MutableGraphView::ReplaceInputs`.
+void ReplaceReferences(const string& from, const string& to, FunctionDef* func);
+
+// Adds a function output to the function def, ensuring that the output key
+// is unique, and maps to output_tensor_name in the ret dict.
+void AddFunctionOutputWithUniqueName(StringPiece prefix,
+                                     StringPiece output_tensor_name,
+                                     FunctionDef* function, DataType dt);
+
+// Adds a node to a FunctionDef.
+NodeDef* AddNode(StringPiece name, StringPiece op,
+                 const std::vector<string>& inputs,
+                 const std::vector<std::pair<string, AttrValue>>& attributes,
+                 FunctionDef* fd);
+
+// Checks whether the function contains a node with the given name.
+bool ContainsFunctionNodeWithName(StringPiece name,
+                                  const FunctionDef& function);
+
+// Checks whether the function contains a node with the given op.
+bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function);
+
+// Checks whether the function contains an output with the given name.
+bool ContainsFunctionOutputWithName(StringPiece name,
+                                    const FunctionDef& function);
+
+// Returns the index of the function input with the given name or -1 if the
+// function input does not exist.
+int FindFunctionInputWithName(StringPiece name, const FunctionDef& function);
+
+// Returns the index of the function output with the given name or -1 if the
+// function output does not exist.
+int FindFunctionOutputWithName(StringPiece name, const FunctionDef& function);
+
+// Returns the index of the function node with the given name or -1 if the
+// function node does not exist.
+int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function);
+
+// Returns the index of the function node with the given op or -1 if the
+// function node does not exist.
+int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function);
+
+// Sets the function node name using the `prefix` as a prefix while
+// guaranteeing the name is unique across the function's nodes.
+void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
+                               NodeDef* node);
+
+}  // end namespace function_utils
+}  // end namespace grappler
+}  // end namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_
diff --git a/tensorflow/core/grappler/optimizers/data/function_utils_test.cc b/tensorflow/core/grappler/optimizers/data/function_utils_test.cc
new file mode 100644
index 0000000..3739e20
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/function_utils_test.cc
@@ -0,0 +1,164 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/tools/graph_transforms/transform_utils.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace function_utils {
+namespace {
+
+TEST(FunctionDefTensorDesc, Parsing) {
+  FunctionDefTensorDesc f("Cast:y:0");
+  EXPECT_EQ(f.full_str, "Cast:y:0");
+  EXPECT_EQ(f.node_name, "Cast");
+  EXPECT_EQ(f.node_output, "y");
+  EXPECT_EQ(f.position, 0);
+
+  FunctionDefTensorDesc f2("Arg0");
+  EXPECT_EQ(f2.full_str, "Arg0");
+  EXPECT_EQ(f2.node_name, "Arg0");
+  EXPECT_EQ(f2.node_output, "");
+  EXPECT_EQ(f2.position, -1);
+}
+
+TEST(ReplaceReferencesTest, ReplaceReferencesTest) {
+  FunctionDef outer = FunctionDefHelper::Create(
+      "outer", {"arg0: int32"}, {"out: int32", "out2: int64"}, {}, {},
+      {{"out", "MapDefun:output:0"}, {"out2", "Cast:y:0"}});
+  NodeDef* derive_node =
+      AddNode("X", "Some_Op", {"MapDefun:output:0"}, {}, &outer);
+  // Check that both the input to "X" and retval of "outer" are replaced.
+  ReplaceReferences("MapDefun:output:0", "arg0", &outer);
+  EXPECT_EQ(outer.ret().at("out"), "arg0");
+  EXPECT_EQ(derive_node->input(0), "arg0");
+}
+
+TEST(FunctionUtilsTest, AddFunctionOutputWithUniqueName) {
+  FunctionDef function = test::function::XTimesTwo();
+  AddFunctionOutputWithUniqueName("y", "two", &function, DT_INT64);
+  EXPECT_TRUE(ContainsFunctionOutputWithName("y/_1", function));
+  EXPECT_EQ(function.ret().at("y/_1"), "two");
+}
+
+TEST(FunctionUtilsTest, ContainsFunctionNodeWithName) {
+  FunctionDef function = test::function::XTimesTwo();
+  EXPECT_FALSE(ContainsFunctionNodeWithName(
+      "weird_name_that_should_not_be_there", function));
+  EXPECT_TRUE(ContainsFunctionNodeWithName("two", function));
+}
+
+TEST(FunctionUtilsTest, ContainsFunctionNodeWithOp) {
+  FunctionDef function = test::function::XTimesTwo();
+  EXPECT_FALSE(ContainsFunctionNodeWithOp("weird_op_that_should_not_be_there",
+                                          function));
+  EXPECT_TRUE(ContainsFunctionNodeWithOp("Mul", function));
+}
+
+TEST(FunctionUtilsTest, ContainsFunctionOutputWithName) {
+  FunctionDef function = test::function::XTimesTwo();
+  EXPECT_TRUE(ContainsFunctionOutputWithName("y", function));
+  EXPECT_FALSE(ContainsFunctionOutputWithName("Add:z:0", function));
+}
+
+TEST(FunctionUtilsTest, FindFunctionNodeWithName) {
+  FunctionDef function = test::function::XTimesTwo();
+  EXPECT_EQ(
+      FindFunctionNodeWithName("weird_name_that_should_not_be_there", function),
+      -1);
+  EXPECT_NE(FindFunctionNodeWithName("two", function), -1);
+}
+
+TEST(FunctionUtilsTest, FindFunctionNodeWithOp) {
+  FunctionDef function = test::function::XTimesTwo();
+  EXPECT_EQ(
+      FindFunctionNodeWithOp("weird_op_that_should_not_be_there", function),
+      -1);
+  EXPECT_NE(FindFunctionNodeWithOp("Mul", function), -1);
+}
+
+TEST(FunctionUtilsTest, FindFunctionInputWithName) {
+  FunctionDef function = test::function::XTimesTwo();
+  EXPECT_EQ(FindFunctionInputWithName("x", function), 0);
+  EXPECT_EQ(FindFunctionInputWithName("not_a_name", function), -1);
+}
+
+TEST(FunctionUtilsTest, FindFunctionOutputWithName) {
+  FunctionDef function = test::function::XTimesTwo();
+  EXPECT_EQ(FindFunctionOutputWithName("y", function), 0);
+  EXPECT_EQ(FindFunctionOutputWithName("Add:z:0", function), -1);
+}
+
+TEST(FunctionUtilsTest, SetUniqueFunctionNodeName) {
+  FunctionDef function = test::function::XTimesTwo();
+  NodeDef node;
+  SetUniqueFunctionNodeName("abc", &function, &node);
+  for (const NodeDef& function_node : function.node_def()) {
+    EXPECT_NE(node.name(), function_node.name());
+  }
+  auto* new_node = function.add_node_def();
+  *new_node = node;
+
+  NodeDef other;
+  SetUniqueFunctionNodeName("abc", &function, &other);
+  EXPECT_NE(other.name(), new_node->name());
+}
+
+TEST(FunctionUtilsTest, AddNodeToFunctionDef) {
+  FunctionDef func;
+  const char* op_name = "xxx";
+  AddNode(op_name, op_name, {}, {}, &func);
+
+  const NodeDef& node1 = func.node_def(FindFunctionNodeWithName("xxx", func));
+  EXPECT_EQ(node1.op(), op_name);
+  EXPECT_EQ(node1.input_size(), 0);
+  EXPECT_EQ(node1.attr_size(), 0);
+
+  const std::vector<string> inputs({"input1", "input2"});
+  AddNode("", op_name, inputs, {}, &func);
+  const NodeDef& node2 =
+      func.node_def(FindFunctionNodeWithName("xxx/_2", func));
+  EXPECT_EQ(node2.op(), op_name);
+  EXPECT_EQ(node2.attr_size(), 0);
+  EXPECT_EQ(node2.input_size(), inputs.size());
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    EXPECT_EQ(node2.input(i), inputs[i]);
+  }
+
+  AttrValue a1, a2;
+  a1.set_type(DT_INT32);
+  a2.set_type(DT_INT64);
+  const std::vector<std::pair<string, AttrValue>> attrs(
+      {{"attr1", a1}, {"attr2", a2}});
+  AddNode("", op_name, {}, attrs, &func);
+  const NodeDef& node3 =
+      func.node_def(FindFunctionNodeWithName("xxx/_3", func));
+  EXPECT_EQ(node3.op(), op_name);
+  EXPECT_EQ(node3.input_size(), 0);
+  EXPECT_EQ(node3.attr_size(), attrs.size());
+  for (size_t i = 0; i < attrs.size(); ++i) {
+    EXPECT_EQ(attrs[i].second.type(), node3.attr().at(attrs[i].first).type());
+  }
+}
+
+}  // namespace
+}  // namespace function_utils
+}  // namespace grappler
+}  // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc
index 01a78c0..b3bfee1 100644
--- a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc
@@ -22,6 +22,7 @@
 #include "tensorflow/core/grappler/mutable_graph_view.h"
 #include "tensorflow/core/grappler/op_types.h"
 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
 #include "tensorflow/core/grappler/utils.h"
 #include "tensorflow/core/lib/gtl/flatmap.h"
@@ -407,7 +408,7 @@
   auto* if_node = fused_function->add_node_def();
   // This is guaranteed to succeed.
   TF_CHECK_OK(if_builder.Finalize(if_node));
-  graph_utils::SetUniqueFunctionNodeName("cond", fused_function, if_node);
+  function_utils::SetUniqueFunctionNodeName("cond", fused_function, if_node);
 
   GetMutableOutputNode(fused_function, 0) = if_node->name() + ":output:0";
 }
diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc b/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc
index d5c6466..e667aff 100644
--- a/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc
@@ -19,6 +19,7 @@
 #include "tensorflow/core/framework/function_testlib.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
 
 #include "tensorflow/core/lib/core/status_test_util.h"
@@ -110,9 +111,9 @@
   CheckUniqueNames(*fused_function);
 
   ASSERT_TRUE(
-      graph_utils::ContainsFunctionNodeWithOp("Equal", *fused_function));
+      function_utils::ContainsFunctionNodeWithOp("Equal", *fused_function));
   const auto &equal_node = fused_function->node_def(
-      graph_utils::FindFunctionNodeWithOp("Equal", *fused_function));
+      function_utils::FindFunctionNodeWithOp("Equal", *fused_function));
 
   EXPECT_EQ(xtimes_two->signature().output_arg(0).name(),
             fused_function->signature().output_arg(0).name());
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
index 5a7fe19..b3f60e3 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
@@ -108,26 +108,6 @@
   return graph->AddNode(std::move(node));
 }
 
-NodeDef* AddNode(StringPiece name, StringPiece op,
-                 const std::vector<string>& inputs,
-                 const std::vector<std::pair<string, AttrValue>>& attributes,
-                 FunctionDef* fd) {
-  NodeDef* node = fd->add_node_def();
-  if (!name.empty()) {
-    node->set_name(string(name));
-  } else {
-    SetUniqueFunctionNodeName(op, fd, node);
-  }
-  node->set_op(string(op));
-  for (const string& input : inputs) {
-    node->add_input(input);
-  }
-  for (auto attr : attributes) {
-    (*node->mutable_attr())[attr.first] = attr.second;
-  }
-  return node;
-}
-
 template <>
 NodeDef* AddScalarConstNode(bool v, MutableGraphView* graph) {
   return AddScalarConstNodeHelper(
@@ -196,6 +176,11 @@
   return true;
 }
 
+bool ContainsGraphFunctionWithName(StringPiece name,
+                                   const FunctionDefLibrary& library) {
+  return FindGraphFunctionWithName(name, library) != -1;
+}
+
 bool ContainsGraphNodeWithName(StringPiece name, const GraphDef& graph) {
   return FindGraphNodeWithName(name, graph) != -1;
 }
@@ -204,18 +189,14 @@
   return FindGraphNodeWithOp(op, graph) != -1;
 }
 
-bool ContainsGraphFunctionWithName(StringPiece name,
-                                   const FunctionDefLibrary& library) {
-  return FindGraphFunctionWithName(name, library) != -1;
-}
-
-bool ContainsFunctionNodeWithName(StringPiece name,
-                                  const FunctionDef& function) {
-  return FindFunctionNodeWithName(name, function) != -1;
-}
-
-bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
-  return FindFunctionNodeWithOp(op, function) != -1;
+int FindGraphFunctionWithName(StringPiece name,
+                              const FunctionDefLibrary& library) {
+  std::vector<int> indices = GetElementIndicesWithPredicate(
+      [&name](const FunctionDef& function) {
+        return function.signature().name() == name;
+      },
+      library.function());
+  return indices.empty() ? -1 : indices.front();
 }
 
 int FindGraphNodeWithName(StringPiece name, const GraphDef& graph) {
@@ -237,31 +218,6 @@
       [&op](const NodeDef& node) { return node.op() == op; }, graph.node());
 }
 
-int FindGraphFunctionWithName(StringPiece name,
-                              const FunctionDefLibrary& library) {
-  std::vector<int> indices = GetElementIndicesWithPredicate(
-      [&name](const FunctionDef& function) {
-        return function.signature().name() == name;
-      },
-      library.function());
-  return indices.empty() ? -1 : indices.front();
-}
-
-int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function) {
-  std::vector<int> indices = GetElementIndicesWithPredicate(
-      [&name](const NodeDef& node) { return node.name() == name; },
-      function.node_def());
-  return indices.empty() ? -1 : indices.front();
-}
-
-int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
-  std::vector<int> indices = GetElementIndicesWithPredicate(
-      [&op](const NodeDef& node) { return node.op() == op; },
-      function.node_def());
-
-  return indices.empty() ? -1 : indices.front();
-}
-
 NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph) {
   if (node.input_size() == 0) return nullptr;
   GraphView::InputPort input_port = graph.GetInputPort(node.name(), 0);
@@ -273,7 +229,7 @@
   string name = string(prefix);
   int id = graph->node_size();
   while (ContainsGraphNodeWithName(name, *graph)) {
-    if (name.rfind("_generated") != std::string::npos &&
+    if (name.rfind("_generated") != string::npos &&
         (name.rfind("_generated") == (name.size() - strlen("_generated")))) {
       name.insert(name.rfind("_generated"), strings::StrCat("/_", id));
     } else {
@@ -284,17 +240,6 @@
   node->set_name(std::move(name));
 }
 
-void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
-                               NodeDef* node) {
-  string name = string(prefix);
-  int id = function->node_def_size();
-  while (ContainsFunctionNodeWithName(name, *function)) {
-    name = strings::StrCat(prefix, "/_", id);
-    ++id;
-  }
-  node->set_name(std::move(name));
-}
-
 void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
                                 FunctionDef* function) {
   string name = string(prefix);
@@ -305,7 +250,6 @@
   }
   function->mutable_signature()->set_name(std::move(name));
 }
-
 }  // end namespace graph_utils
 }  // end namespace grappler
 }  // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h
index 6f431c2..1652afc 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h
@@ -37,12 +37,6 @@
                  const std::vector<std::pair<string, AttrValue>>& attributes,
                  MutableGraphView* graph);
 
-// Adds a node to a FunctionDef.
-NodeDef* AddNode(StringPiece name, StringPiece op,
-                 const std::vector<string>& inputs,
-                 const std::vector<std::pair<string, AttrValue>>& attributes,
-                 FunctionDef* fd);
-
 // Adds a Const node with the given value to the graph.
 template <typename T>
 NodeDef* AddScalarConstNode(T v, MutableGraphView* graph) {
@@ -76,13 +70,6 @@
 bool ContainsGraphFunctionWithName(StringPiece name,
                                    const FunctionDefLibrary& library);
 
-// Checks whether the function contains a node with the given name.
-bool ContainsFunctionNodeWithName(StringPiece name,
-                                  const FunctionDef& function);
-
-// Checks whether the function contains a node with the given op.
-bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function);
-
 // Checks whether the graph contains a node with the given op.
 bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph);
 
@@ -95,14 +82,6 @@
 int FindGraphFunctionWithName(StringPiece name,
                               const FunctionDefLibrary& library);
 
-// Returns the index of the function node with the given name or -1 if the
-// function node does not exist.
-int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function);
-
-// Returns the index of the function node with the given op or -1 if the
-// function node does not exist.
-int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function);
-
 // Returns the index of the first node with the given op or -1 if no such node
 // exists.
 int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph);
@@ -119,11 +98,6 @@
 // is unique across the graph.
 void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph, NodeDef* node);
 
-// Sets the function node name using the `prefix` as a prefix while guaranteeing
-// the name is unique across the functions nodes.
-void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
-                               NodeDef* node);
-
 // Sets the node name using the `prefix` name as a prefix while guaranteeing the
 // name is unique across the graph.
 void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
index c19ac7b..6877c20 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
@@ -112,20 +112,6 @@
       ContainsGraphFunctionWithName(new_function->signature().name(), library));
 }
 
-TEST(GraphUtilsTest, ContainsFunctionNodeWithName) {
-  FunctionDef function = test::function::XTimesTwo();
-  EXPECT_FALSE(ContainsFunctionNodeWithName(
-      "weird_name_that_should_not_be_there", function));
-  EXPECT_TRUE(ContainsFunctionNodeWithName("two", function));
-}
-
-TEST(GraphUtilsTest, ContainsFunctionNodeWithOp) {
-  FunctionDef function = test::function::XTimesTwo();
-  EXPECT_FALSE(ContainsFunctionNodeWithOp("weird_op_that_should_not_be_there",
-                                          function));
-  EXPECT_TRUE(ContainsFunctionNodeWithOp("Mul", function));
-}
-
 TEST(GraphUtilsTest, ContainsNodeWithOp) {
   GraphDef graph_def;
   MutableGraphView graph(&graph_def);
@@ -150,22 +136,6 @@
   EXPECT_EQ(FindGraphNodeWithName("A", *graph.GetGraph()), -1);
 }
 
-TEST(GraphUtilsTest, FindFunctionNodeWithName) {
-  FunctionDef function = test::function::XTimesTwo();
-  EXPECT_EQ(
-      FindFunctionNodeWithName("weird_name_that_should_not_be_there", function),
-      -1);
-  EXPECT_NE(FindFunctionNodeWithName("two", function), -1);
-}
-
-TEST(GraphUtilsTest, FindFunctionNodeWithOp) {
-  FunctionDef function = test::function::XTimesTwo();
-  EXPECT_EQ(
-      FindFunctionNodeWithOp("weird_op_that_should_not_be_there", function),
-      -1);
-  EXPECT_NE(FindFunctionNodeWithOp("Mul", function), -1);
-}
-
 TEST(GraphUtilsTest, FindGraphFunctionWithName) {
   FunctionDefLibrary library;
   EXPECT_EQ(FindGraphFunctionWithName("new_function", library), -1);
@@ -225,21 +195,6 @@
   EXPECT_NE(node2->name(), node3->name());
 }
 
-TEST(GraphUtilsTest, SetUniqueFunctionNodeName) {
-  FunctionDef function = test::function::XTimesTwo();
-  NodeDef node;
-  SetUniqueFunctionNodeName("abc", &function, &node);
-  for (const NodeDef& function_node : function.node_def()) {
-    EXPECT_NE(node.name(), function_node.name());
-  }
-  auto* new_node = function.add_node_def();
-  *new_node = node;
-
-  NodeDef other;
-  SetUniqueFunctionNodeName("abc", &function, &other);
-  EXPECT_NE(other.name(), new_node->name());
-}
-
 TEST(GraphUtilsTest, SetUniqueGraphFunctionName) {
   FunctionDefLibrary library;
   FunctionDef* new_function = library.add_function();
@@ -251,43 +206,6 @@
             other_function->signature().name());
 }
 
-TEST(GraphUtilsTest, AddNodeToFunctionDef) {
-  FunctionDef func;
-  const char* op_name = "xxx";
-  AddNode(op_name, op_name, {}, {}, &func);
-
-  const NodeDef& node1 = func.node_def(FindFunctionNodeWithName("xxx", func));
-  EXPECT_EQ(node1.op(), op_name);
-  EXPECT_EQ(node1.input_size(), 0);
-  EXPECT_EQ(node1.attr_size(), 0);
-
-  const std::vector<string> inputs({"input1", "input2"});
-  AddNode("", op_name, inputs, {}, &func);
-  const NodeDef& node2 =
-      func.node_def(FindFunctionNodeWithName("xxx/_2", func));
-  EXPECT_EQ(node2.op(), op_name);
-  EXPECT_EQ(node2.attr_size(), 0);
-  EXPECT_EQ(node2.input_size(), inputs.size());
-  for (size_t i = 0; i < inputs.size(); ++i) {
-    EXPECT_EQ(node2.input(i), inputs[i]);
-  }
-
-  AttrValue a1, a2;
-  a1.set_type(DT_INT32);
-  a2.set_type(DT_INT64);
-  const std::vector<std::pair<string, AttrValue>> attrs(
-      {{"attr1", a1}, {"attr2", a2}});
-  AddNode("", op_name, {}, attrs, &func);
-  const NodeDef& node3 =
-      func.node_def(FindFunctionNodeWithName("xxx/_3", func));
-  EXPECT_EQ(node3.op(), op_name);
-  EXPECT_EQ(node3.input_size(), 0);
-  EXPECT_EQ(node3.attr_size(), attrs.size());
-  for (size_t i = 0; i < attrs.size(); ++i) {
-    EXPECT_EQ(attrs[i].second.type(), node3.attr().at(attrs[i].first).type());
-  }
-}
-
 TEST(GraphUtilsTest, GetInputNode) {
   GraphDef graph_def;
   MutableGraphView graph(&graph_def);
diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization.cc b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc
new file mode 100644
index 0000000..305325e
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc
@@ -0,0 +1,106 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/map_parallelization.h"
+
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/grappler/utils.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+bool CanParallelize(const FunctionDef& function,
+                    const FunctionLibraryDefinition& library) {
+  if (!function.signature().is_stateful()) return true;
+
+  for (const auto& node : function.node_def()) {
+    const OpDef* op_def;
+    TF_CHECK_OK(library.LookUpOpDef(node.op(), &op_def));
+    // Assert is marked as stateful, but it does not have any state (aside
+    // from its I/O side effect). Similarly to CUDA, we do not guarantee that
+    // the Assert that would fail is executed first, so we can parallelize
+    // maps that contain it.
+    if (op_def->is_stateful() && op_def->name() != "Assert") return false;
+  }
+
+  return true;
+}
+
+NodeDef MakeParallelMap(const NodeDef& map_node, MutableGraphView* graph) {
+  NodeDef parallel_map = map_node;
+  graph_utils::SetUniqueGraphNodeName("parallel_map", graph->GetGraph(),
+                                      &parallel_map);
+  parallel_map.set_op("ParallelMapDataset");
+  // TODO(b/114475558): We want to set `num_parallel_calls` to a special value,
+  // so that dynamic tuning will pick the optimal value at runtime. Because
+  // this feature is not yet implemented, we set it to 2, which is the smallest
+  // value that introduces parallelism.
+  auto* num_parallel_calls = graph_utils::AddScalarConstNode(2, graph);
+  parallel_map.add_input(num_parallel_calls->name());
+
+  return parallel_map;
+}
+
+}  // namespace
+
+Status MapParallelization::Optimize(Cluster* cluster, const GrapplerItem& item,
+                                    GraphDef* output) {
+  *output = item.graph;
+  MutableGraphView graph(output);
+  std::set<string> nodes_to_delete;
+  FunctionLibraryDefinition function_library(OpRegistry::Global(),
+                                             item.graph.library());
+  auto get_map_node = [](const NodeDef& node) -> const NodeDef* {
+    if (node.op() == "MapDataset") return &node;
+    return nullptr;
+  };
+
+  for (const NodeDef& node : item.graph.node()) {
+    const NodeDef* map_node = get_map_node(node);
+    if (!map_node) continue;
+
+    auto* function =
+        function_library.Find(map_node->attr().at("f").func().name());
+    if (!CanParallelize(*function, function_library)) continue;
+
+    auto* parallel_map = graph.AddNode(MakeParallelMap(*map_node, &graph));
+    graph.ReplaceInput(*map_node, *parallel_map);
+
+    // TODO(prazek): we could also remove map functions from library if they
+    // are not used anymore.
+    nodes_to_delete.insert(map_node->name());
+  }
+
+  graph.DeleteNodes(nodes_to_delete);
+  return Status::OK();
+}
+
+void MapParallelization::Feedback(Cluster* cluster, const GrapplerItem& item,
+                                  const GraphDef& optimize_output,
+                                  double result) {
+  // no-op
+}
+
+REGISTER_GRAPH_OPTIMIZER_AS(MapParallelization, "map_parallelization");
+
+}  // end namespace grappler
+}  // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization.h b/tensorflow/core/grappler/optimizers/data/map_parallelization.h
new file mode 100644
index 0000000..ac9cf7e
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/map_parallelization.h
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_PARALLELIZATION_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_PARALLELIZATION_H_
+
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// This optimization parallelizes MapDataset when the mapped function is
+// stateless.
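+//
+// As an illustrative sketch (node names are made up), a graph node such as
+//   map = MapDataset(input_dataset, f=MyFn)
+// is rewritten to
+//   parallel_map = ParallelMapDataset(input_dataset, num_parallel_calls, f=MyFn)
+// where num_parallel_calls is a newly added scalar Const node (currently 2),
+// provided MyFn contains no stateful ops other than Assert.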
+class MapParallelization : public CustomGraphOptimizer {
+ public:
+  MapParallelization() = default;
+  ~MapParallelization() override = default;
+
+  string name() const override { return "map_parallelization"; }
+
+  Status Init(
+      const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+    return Status::OK();
+  }
+
+  Status Optimize(Cluster* cluster, const GrapplerItem& item,
+                  GraphDef* output) override;
+
+  void Feedback(Cluster* cluster, const GrapplerItem& item,
+                const GraphDef& optimize_output, double result) override;
+};
+
+}  // end namespace grappler
+}  // end namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_PARALLELIZATION_H_
diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc b/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc
new file mode 100644
index 0000000..b2a5d9b
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc
@@ -0,0 +1,94 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/map_parallelization.h"
+
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name,
+                    StringPiece function_name) {
+  return test::function::NDef(
+      name, "MapDataset", {string(input_node_name)},
+      {{"f", FunctionDefHelper::FunctionRef(string(function_name))},
+       {"Targuments", {}},
+       {"output_shapes", {}},
+       {"output_types", {}}});
+}
+
+const char stateless_fun_name[] = "XTimesTwo";
+const char stateful_fun_name[] = "RandomUniform";
+
+TEST(MapParallelizationTest, ParallelizeSimpleMap) {
+  using test::function::NDef;
+  GrapplerItem item;
+  item.graph = test::function::GDef(
+      {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+       NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+       NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+       NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
+       MakeMapNode("map1", "range", stateless_fun_name)},
+      // FunctionLib
+      {
+          test::function::XTimesTwo(),
+      });
+
+  MapParallelization optimizer;
+  GraphDef output;
+  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+  EXPECT_TRUE(graph_utils::ContainsNodeWithOp("ParallelMapDataset", output));
+  EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map1", output));
+  EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map2", output));
+}
+
+TEST(MapParallelizationTest, ParallelizeAssert) {
+  using test::function::NDef;
+  GrapplerItem item;
+  item.graph = test::function::GDef(
+      {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+       NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+       NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+       NDef("filename", "Const", {}, {{"value", ""}, {"dtype", DT_STRING}}),
+       NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
+       MakeMapNode("map1", "range", stateful_fun_name),
+       MakeMapNode("map2", "map1", stateless_fun_name),
+       NDef("cache", "CacheDataset", {"map2", "filename"}, {})},
+      // FunctionLib
+      {
+          test::function::XTimesTwo(),
+          test::function::RandomUniform(),
+      });
+
+  MapParallelization optimizer;
+  GraphDef output;
+  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+  EXPECT_TRUE(graph_utils::ContainsNodeWithOp("ParallelMapDataset", output));
+  EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName("map1", output));
+  EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map2", output));
+}
+
+}  // namespace
+}  // namespace grappler
+}  // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
index a019b77..7a2f191 100644
--- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
@@ -14,6 +14,7 @@
 ==============================================================================*/
 
 #include "tensorflow/core/grappler/optimizers/data/map_vectorization.h"
+#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
 
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
@@ -24,6 +25,7 @@
 #include "tensorflow/core/grappler/mutable_graph_view.h"
 #include "tensorflow/core/grappler/op_types.h"
 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
 #include "tensorflow/core/grappler/utils.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
@@ -37,11 +39,11 @@
   (*to->mutable_attr())[attr_name] = from.attr().at(attr_name);
 }
 
-FunctionDef* AddVectorizedFunction(const NodeDef& map_node,
+// Returns a FunctionDef containing a MapDefun op that wraps the original
+// function.
+FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node,
                                    const FunctionDef& orig_func,
                                    FunctionDefLibrary* library) {
-  // If we decide to use a different method of vectorization, we can just
-  // swap out this part.
   FunctionDef* vectorized_func = library->add_function();
   // Function inputs and outputs are the same as original, just
   // with different shapes.
@@ -52,8 +54,8 @@
   // Add MapDefun node
   NodeDef* map_defun_node = vectorized_func->mutable_node_def()->Add();
   map_defun_node->set_op("MapDefun");
-  graph_utils::SetUniqueFunctionNodeName(map_defun_node->op(), vectorized_func,
-                                         map_defun_node);
+  function_utils::SetUniqueFunctionNodeName(map_defun_node->op(),
+                                            vectorized_func, map_defun_node);
 
   // Set attrs and inputs
   for (const string& k : {"f", "output_types", "output_shapes"}) {
@@ -81,6 +83,30 @@
   return vectorized_func;
 }
 
+FunctionDef* AddVectorizedFunction(const NodeDef& map_node,
+                                   const FunctionDef& orig_func,
+                                   FunctionDefLibrary* library) {
+  // Vectorizes orig_func naively by wrapping it in a MapDefun op, then
+  // performs more efficient vectorization with VectorizeMapDefun.
+  FunctionDef* vectorized_func =
+      CreateMapDefunWrapper(map_node, orig_func, library);
+  NodeDef* map_defun_node = vectorized_func->mutable_node_def()->Mutable(0);
+  DCHECK_EQ(map_defun_node->op(), "MapDefun");
+
+  // Create a copy of the original function so that we can mutate it, and
+  // attach that to the map defun node.
+  FunctionDef* map_defun_fn = library->add_function();
+  *map_defun_fn = orig_func;
+  graph_utils::SetUniqueGraphFunctionName(orig_func.signature().name(), library,
+                                          map_defun_fn);
+  (*map_defun_node->mutable_attr())["f"].mutable_func()->set_name(
+      map_defun_fn->signature().name());
+
+  vectorization_utils::VectorizeMapDefun(vectorized_func, map_defun_fn,
+                                         map_defun_node);
+  return vectorized_func;
+}
+
 bool IsOutputShapesFullyDefined(const NodeDef& node) {
   auto* shapes_attr = gtl::FindOrNull(node.attr(), "output_shapes");
   if (shapes_attr == nullptr) return false;
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
new file mode 100644
index 0000000..bfca63b
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
@@ -0,0 +1,346 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
+
+#include "absl/strings/str_join.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/functions.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/scanner.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+using function_utils::FunctionDefTensorDesc;
+
+namespace {
+
+void AddMapDefunOutput(FunctionDef* map_defun_fn, NodeDef* map_defun_node,
+                       const string& output_retval, const DataType t) {
+  // Set to unknown shape
+  TensorShapeProto tensor_shape_proto;
+  PartialTensorShape().AsProto(&tensor_shape_proto);
+
+  function_utils::AddFunctionOutputWithUniqueName(
+      "vectorized_out", output_retval, map_defun_fn, t);
+
+  *(*map_defun_node->mutable_attr())["output_shapes"]
+       .mutable_list()
+       ->add_shape() = tensor_shape_proto;
+  (*map_defun_node->mutable_attr())["output_types"].mutable_list()->add_type(t);
+}
+
+void RemoveMapDefunOutput(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
+                          NodeDef* map_defun_node, int output_position) {
+  DCHECK_LT(output_position, map_defun_fn->signature().output_arg_size())
+      << "Trying to remove output that doesn't exist. Output number: "
+      << output_position;
+
+  int num_later_outputs =
+      map_defun_fn->signature().output_arg_size() - output_position - 1;
+
+  // Remove from map_defun_fn's ret dict and output args
+  map_defun_fn->mutable_ret()->erase(
+      map_defun_fn->signature().output_arg(output_position).name());
+  map_defun_fn->mutable_signature()->mutable_output_arg()->DeleteSubrange(
+      output_position, 1);
+
+  // Renumber outputs that come after
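+  // (For example, with three outputs and output_position == 1, references in
+  // outer_scope to "<map_defun_node>:output:2" become
+  // "<map_defun_node>:output:1".)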
+  for (int i = 0; i < num_later_outputs; ++i) {
+    function_utils::ReplaceReferences(
+        strings::StrCat(map_defun_node->name(),
+                        ":output:", output_position + i + 1),
+        strings::StrCat(map_defun_node->name(),
+                        ":output:", output_position + i),
+        outer_scope);
+  }
+  map_defun_node->mutable_attr()
+      ->at("output_shapes")
+      .mutable_list()
+      ->mutable_shape()
+      ->DeleteSubrange(output_position, 1);
+  map_defun_node->mutable_attr()
+      ->at("output_types")
+      .mutable_list()
+      ->mutable_type()
+      ->ExtractSubrange(output_position, 1, nullptr);
+}
+
+Status ConvertCastOp(FunctionDef* outer_scope, gtl::ArraySlice<string> inputs,
+                     const NodeDef& cast_node,
+                     std::map<string, string>* conversion_map) {
+  if (inputs.size() != 1) {
+    return errors::Internal("Cast op should only have one input.");
+  }
+
+  // Add new Cast node
+  NodeDef* new_cast_node = outer_scope->add_node_def();
+  *new_cast_node = cast_node;
+  new_cast_node->clear_name();
+  function_utils::SetUniqueFunctionNodeName(
+      strings::StrCat("vectorized/", cast_node.name()), outer_scope,
+      new_cast_node);
+  new_cast_node->set_input(0, inputs[0]);
+
+  // Add the output mapping to conversion map
+  (*conversion_map)[strings::StrCat(cast_node.name(), ":y:0")] =
+      strings::StrCat(new_cast_node->name(), ":y:0");
+
+  return Status::OK();
+}
+
+Status ConvertUnpackOp(FunctionDef* outer_scope, gtl::ArraySlice<string> inputs,
+                       const NodeDef& unpack_node,
+                       std::map<string, string>* conversion_map) {
+  if (inputs.size() != 1) {
+    return errors::Internal("Unpack op should only have one input.");
+  }
+
+  // Add new Unpack node
+  NodeDef* new_unpack_node = outer_scope->add_node_def();
+  *new_unpack_node = unpack_node;
+  new_unpack_node->clear_name();
+  function_utils::SetUniqueFunctionNodeName(
+      strings::StrCat("vectorized/", unpack_node.name()), outer_scope,
+      new_unpack_node);
+
+  // Increment "axis" attr by 1:
+  (*new_unpack_node->mutable_attr())["axis"].set_i(
+      unpack_node.attr().at("axis").i() + 1);
+  new_unpack_node->set_input(0, inputs[0]);
+
+  // Add the output mappings to conversion map
+  int num = new_unpack_node->attr().at("num").i();
+  for (int i = 0; i < num; ++i) {
+    (*conversion_map)[strings::StrCat(unpack_node.name(), ":output:", i)] =
+        strings::StrCat(new_unpack_node->name(), ":output:", i);
+  }
+
+  return Status::OK();
+}
+
+int FindOutputToConvert(const FunctionDef& function,
+                        const std::set<string>& unconvertible,
+                        FunctionDefTensorDesc* f) {
+  for (int i = function.signature().output_arg_size() - 1; i >= 0; --i) {
+    const string& ret_key = function.signature().output_arg(i).name();
+    *f = FunctionDefTensorDesc(function.ret().at(ret_key));
+
+    if (unconvertible.find(f->node_name) == unconvertible.end()) {
+      return i;
+    }
+  }
+  return -1;
+}
+
+// Helper class that vectorizes the body of a MapDefun node, adding new
+// operations to the graph that collectively compute the same value as what
+// running the MapDefun function on slices of the input would produce.
+// Each instance of the class encapsulates all the data necessary to vectorize a
+// MapDefun op in place.
+class Vectorization {
+ public:
+  Vectorization(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
+                NodeDef* map_defun_node)
+      : outer_scope_(outer_scope),
+        map_defun_fn_(map_defun_fn),
+        map_defun_node_(map_defun_node) {}
+
+  // Repeatedly tries to convert outputs of map_defun_fn_ into new nodes in
+  // the outer_scope_, until there are no convertible outputs remaining.
+  // This method is idempotent.
+  void Vectorize();
+
+ private:
+  // Vectorizes the map defun function's output at output_position
+  Status ConvertOutput(int output_position, const FunctionDefTensorDesc& desc);
+  // Given a descriptor of the original output tensor, gets a string
+  // corresponding to the converted output tensor.
+  Status ConvertOutputHelper(const FunctionDefTensorDesc& output_desc,
+                             string* converted);
+  Status AddConversionMappingFromInput(
+      const FunctionDefTensorDesc& output_desc);
+
+  // Adds mappings from node's output tensors to converted output tensors,
+  // creating the necessary new node(s). Generally, the steps to convert an op
+  // are:
+  // 1) Promote the inputs of the op to outputs of the map_defun_fn_,
+  //    and modify map_defun_node_ attrs accordingly
+  // 2) Create new node(s) in outer_scope_ that act on batched input tensors.
+  //    These operations collectively compute the same value as what running
+  //    the original operation on slices of the input tensors would produce.
+  //    For example, a Cast op in MapDefun translates to a Cast op in
+  //    outer_scope_, since the vectorized version of Cast is itself.
+  // 3) Set inputs of new node(s) to the corresponding converted inputs (that
+  //    are now outputs of map_defun_node_)
+  // 4) For each output of the old node, add the mapping of output strings to
+  //    the conversion map (e.g. "Cast:y:0" -> "vectorized/Cast:y:0")
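+  //
+  // Concretely (names are illustrative): for a Cast node inside map_defun_fn_
+  // whose input is "arg0", "arg0" is promoted to a new output of
+  // map_defun_fn_, a Cast node reading "<map_defun_node>:output:<new index>"
+  // is added to outer_scope_, and the mapping
+  // "Cast:y:0" -> "vectorized/Cast:y:0" is recorded.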
+  Status AddConversionMappingFromOp(const NodeDef& node,
+                                    const FunctionDefTensorDesc& output_desc);
+
+  // Maps a tensor name to the name of the corresponding vectorized tensor. For
+  // example, "Cast:y:0" -> "Vectorize/Cast:y:0"
+  std::map<string, string> conversion_map_;
+  // Unconvertible node names
+  std::set<string> unconvertible_;
+
+  FunctionDef* outer_scope_;
+  FunctionDef* map_defun_fn_;
+  NodeDef* map_defun_node_;
+};
+
+Status Vectorization::AddConversionMappingFromOp(
+    const NodeDef& node, const FunctionDefTensorDesc& output_desc) {
+  for (const string& input_name : node.input()) {
+    if (IsControlInput(input_name)) {
+      return errors::InvalidArgument(
+          "Vectorizing outputs with control inputs is currently not "
+          "supported.");
+    }
+  }
+
+  // TODO(rachelim): Have some mechanism for registering converters and some
+  // uniform, simpler way to represent them.
+
+  DataTypeVector types;
+  const OpDef* op_def = nullptr;
+  TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node.op(), &op_def));
+  TF_RETURN_IF_ERROR(InputTypesForNode(node, *op_def, &types));
+
+  std::vector<string> promoted_inputs;
+  promoted_inputs.reserve(node.input_size());
+  for (int i = 0; i < node.input_size(); ++i) {
+    promoted_inputs.push_back(strings::StrCat(
+        map_defun_node_->name(),
+        ":output:", map_defun_fn_->signature().output_arg_size() + i));
+  }
+
+  if (node.op() == "Cast") {
+    TF_RETURN_IF_ERROR(
+        ConvertCastOp(outer_scope_, promoted_inputs, node, &conversion_map_));
+  } else if (node.op() == "Unpack") {
+    TF_RETURN_IF_ERROR(
+        ConvertUnpackOp(outer_scope_, promoted_inputs, node, &conversion_map_));
+  } else {
+    return errors::Unimplemented("Op converter for \"", node.op(),
+                                 "\" not implemented yet");
+  }
+
+  // If we get here, the conversion was successful, so we promote the inputs
+  // of the op to MapDefun outputs.
+  for (int i = 0; i < types.size(); ++i) {
+    AddMapDefunOutput(map_defun_fn_, map_defun_node_, node.input(i), types[i]);
+  }
+
+  return Status::OK();
+}
+
+Status Vectorization::AddConversionMappingFromInput(
+    const FunctionDefTensorDesc& output_desc) {
+  int input_index = function_utils::FindFunctionInputWithName(
+      output_desc.node_name, *map_defun_fn_);
+  if (input_index == -1) {
+    return errors::Internal("Cannot convert non-existent input.");
+  }
+
+  conversion_map_[output_desc.full_str] = map_defun_node_->input(input_index);
+  return Status::OK();
+}
+
+Status Vectorization::ConvertOutputHelper(
+    const FunctionDefTensorDesc& output_desc, string* converted) {
+  // It's possible the output already has a mapping, if it comes from a node
+  // that has already been converted.
+  if (auto found = gtl::FindOrNull(conversion_map_, output_desc.full_str)) {
+    *converted = *found;
+    return Status::OK();
+  }
+
+  int index = function_utils::FindFunctionNodeWithName(output_desc.node_name,
+                                                       *map_defun_fn_);
+  if (index == -1) {  // The output comes from an input
+    TF_RETURN_IF_ERROR(AddConversionMappingFromInput(output_desc));
+  } else {
+    TF_RETURN_IF_ERROR(AddConversionMappingFromOp(
+        map_defun_fn_->node_def(index), output_desc));
+  }
+  *converted = conversion_map_.at(output_desc.full_str);
+  return Status::OK();
+}
+
+Status Vectorization::ConvertOutput(int output_position,
+                                    const FunctionDefTensorDesc& output_desc) {
+  string converted_output_name;
+  TF_RETURN_IF_ERROR(ConvertOutputHelper(output_desc, &converted_output_name));
+
+  // Remove the old output and make everything that referenced it point
+  // to the new string
+  function_utils::ReplaceReferences(
+      strings::StrCat(map_defun_node_->name(), ":output:", output_position),
+      converted_output_name, outer_scope_);
+  RemoveMapDefunOutput(outer_scope_, map_defun_fn_, map_defun_node_,
+                       output_position);
+
+  return Status::OK();
+}
+
+void Vectorization::Vectorize() {
+  while (true) {
+    FunctionDefTensorDesc desc;
+    int output_position =
+        FindOutputToConvert(*map_defun_fn_, unconvertible_, &desc);
+    if (output_position == -1) break;
+
+    if (!ConvertOutput(output_position, desc).ok()) {
+      unconvertible_.insert(desc.node_name);
+    }
+  }
+
+  // If we've converted all the outputs of the MapDefun function, we no longer
+  // need the MapDefun node and can delete it.
+  if (map_defun_fn_->signature().output_arg_size() == 0) {
+    outer_scope_->mutable_node_def()->DeleteSubrange(
+        function_utils::FindFunctionNodeWithName(map_defun_node_->name(),
+                                                 *outer_scope_),
+        1);
+  }
+
+  if (!unconvertible_.empty()) {
+    VLOG(2) << "The following nodes could not be converted: ["
+            << absl::StrJoin(unconvertible_, ", ") << "].";
+  }
+}
+}  // namespace
+
+void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
+                       NodeDef* map_defun_node) {
+  Vectorization(outer_scope, map_defun_fn, map_defun_node).Vectorize();
+}
+
+}  // end namespace vectorization_utils
+}  // end namespace grappler
+}  // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.h b/tensorflow/core/grappler/optimizers/data/vectorization_utils.h
new file mode 100644
index 0000000..bb405fa
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.h
@@ -0,0 +1,90 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_UTILS_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_UTILS_H_
+
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+// Given a function, `map_defun_fn`, that is mapped across some input vector
+// elements via a MapDefun operation, `VectorizeMapDefun` attempts to
+// vectorize the MapDefun by "lifting" operations from the `map_defun_fn` to the
+// `outer_scope`; that is, replacing `map_defun_fn` operations with new
+// `outer_scope` operations that produce the same vector output(s) as executing
+// the `map_defun_fn` operations on elements of vector input(s) would. If all
+// `map_defun_fn` operations are successfully lifted, `map_defun_node` is
+// eliminated from `outer_scope` altogether. However, if some operations cannot
+// be lifted and the vectorization only succeeds partially, `map_defun_node` is
+// kept in `outer_scope` and continues to handle the operations that were not
+// lifted.
+//
+// Example:
+//   If the input to the `VectorizeMapDefun` function is a MapDefun
+// whose `map_defun_fn` performs the Cast operation, the vectorization will
+// eliminate the MapDefun. This is because the Cast operation supports
+// any tensor shape and can thus be lifted to the `outer_scope`.
+//
+// Before:
+//
+//
+// outer_scope     +------+
+// +---------------+ Arg0 +---------+
+// |               +---+--+         |
+// |                   |            |
+// |  map_defun_fn +---v--+         |
+// |   +-----------+ Arg0 +-----+   |
+// |   |           +---+--+     |   |
+// |   |               |        |   |
+// |   |               |        |   |
+// |   |           +---v--+     |   |
+// |   |           | Cast |     |   |
+// |   |           +---+--+     |   |
+// |   |               |        |   |
+// |   |           +---v--+     |   |
+// |   +-----------+ Ret0 +-----+   |
+// |               +---+--+         |
+// |                   |            |
+// |               +---v--+         |
+// +---------------+ Ret0 +---------+
+//                 +------+
+//
+//
+// After:
+//
+// outer_scope     +------+
+// +---------------+ Arg0 +---------+
+// |               +---+--+         |
+// |                   |            |
+// |               +---v--+         |
+// |               | Cast |         |
+// |               +---+--+         |
+// |                   |            |
+// |               +---v--+         |
+// +---------------+ Ret0 +---------+
+//                 +------+
+//
+void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
+                       NodeDef* map_defun_node);
+
+}  // end namespace vectorization_utils
+}  // end namespace grappler
+}  // end namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_UTILS_H_
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
new file mode 100644
index 0000000..e129fa9
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
@@ -0,0 +1,600 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
+
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/tools/graph_transforms/transform_utils.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+namespace {
+
+NodeDef* AddCastNode(const string& name, const std::vector<string>& inputs,
+                     DataType src, DataType dst, bool truncate,
+                     FunctionDef* fn) {
+  NodeDef* node = function_utils::AddNode(name, "Cast", inputs, {}, fn);
+  graph_transforms::SetNodeAttr("SrcT", src, node);
+  graph_transforms::SetNodeAttr("DstT", dst, node);
+  graph_transforms::SetNodeAttr("Truncate", truncate, node);
+  return node;
+}
+
+NodeDef* AddUnstackNode(const string& name, const std::vector<string>& inputs,
+                        DataType t, int axis, int num, FunctionDef* fn) {
+  NodeDef* node = function_utils::AddNode(name, "Unpack", inputs, {}, fn);
+  graph_transforms::SetNodeAttr("T", t, node);
+  graph_transforms::SetNodeAttr("axis", axis, node);
+  graph_transforms::SetNodeAttr("num", num, node);
+  return node;
+}
+
+NodeDef* AddMapDefunNode(const string& name, const std::vector<string>& inputs,
+                         const std::vector<DataType>& t_arguments,
+                         const std::vector<DataType>& output_types,
+                         const std::vector<TensorShape>& output_shapes,
+                         const string& function_name, FunctionDef* fn) {
+  NameAttrList func;
+  func.set_name(function_name);
+  NodeDef* node = function_utils::AddNode(name, "MapDefun", inputs, {}, fn);
+  graph_transforms::SetNodeAttr("Targuments", t_arguments, node);
+  graph_transforms::SetNodeAttr("output_types", output_types, node);
+  graph_transforms::SetNodeAttr("output_shapes", output_shapes, node);
+  graph_transforms::SetNodeAttr("f", func, node);
+  return node;
+}
+
+// TODO(rachelim): Use FunctionDefHelper::Create instead
+FunctionDef CreateFunction(
+    StringPiece name, const std::vector<std::pair<string, DataType>>& inputs,
+    const std::vector<std::pair<string, DataType>>& outputs,
+    const std::map<string, string>& rets) {
+  FunctionDef func;
+  auto* signature = func.mutable_signature();
+  signature->set_name(string(name));
+  for (const auto& x : inputs) {
+    auto* arg_def = signature->add_input_arg();
+    arg_def->set_name(x.first);
+    arg_def->set_type(x.second);
+  }
+  for (const auto& x : outputs) {
+    auto* arg_def = signature->add_output_arg();
+    arg_def->set_name(x.first);
+    arg_def->set_type(x.second);
+  }
+  for (const auto& x : rets) {
+    (*func.mutable_ret())[x.first] = x.second;
+  }
+
+  return func;
+}
+
+TEST(FunctionDefInputDescTest, ConstructedCorrectly) {}
+
+// Before:
+//
+//                 +------+   +------+
+// +---------------+ Arg0 +---+ Arg1 +--------+
+// |               +---+--+   +---+--+        |
+// |                   |          |           |
+// |               +---v--+   +---v--+        |
+// |   +-----------+ Arg0 +---+ Arg1 +----+   |
+// |   |           +---+--+   +---+--+    |   |
+// |   |               |          |       |   |
+// |   | MapDefun  +---v--+   +---v--+    |   |
+// |   +-----------+ Ret0 +---+ Ret1 +----+   |
+// |               +---+--+   +---+--+        |
+// |                   |          |           |
+// |               +---v--+   +---v--+        |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+//                 +------+   +------+
+//
+//
+//  After:
+//
+//                 +------+   +------+
+// +---------------+ Arg0 +---+ Arg1 +--------+
+// |               +---+--+   +---+--+        |
+// |                   |          |           |
+// |                   |          |           |
+// |                   |          |           |
+// |               +---v--+   +---v--+        |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+//                 +------+   +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
+  FunctionDef inner =
+      CreateFunction("inner_function", {{"arg0", DT_INT32}, {"arg1", DT_INT32}},
+                     {{"ret0", DT_INT32}, {"ret1", DT_INT32}},
+                     {{"ret0", "arg0"}, {"ret1", "arg1"}});
+  FunctionDef outer = CreateFunction(
+      "outer_function", {{"ret0", DT_INT32}, {"ret1", DT_INT32}},
+      {{"mapdefun", DT_INT32}, {"mapdefun_0", DT_INT32}},
+      {{"mapdefun", "MapDefun:output:0"}, {"mapdefun_0", "MapDefun:output:1"}});
+
+  NodeDef* map_defun = AddMapDefunNode(
+      "MapDefun", {"ret0", "ret1"}, {DT_INT32, DT_INT32}, {DT_INT32, DT_INT32},
+      {{}, {}}, inner.signature().name(), &outer);
+  CHECK_NOTNULL(map_defun);
+
+  VectorizeMapDefun(&outer, &inner, map_defun);
+  EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
+  EXPECT_EQ(outer.ret().at("mapdefun"), "ret0");
+  EXPECT_EQ(outer.ret().at("mapdefun_0"), "ret1");
+}
+
+// Before:
+//
+//                 +------+   +------+
+// +---------------+ Arg0 +---+ Arg1 +--------+
+// |               +---+--+   +---+--+        |
+// |                   |          |           |
+// |               +---v--+   +---v--+        |
+// |   +-----------+ Arg0 +---+ Arg1 +----+   |
+// |   |           +---+--+   +---+--+    |   |
+// |   |               |          |       |   |
+// |   |   +------+    |      +---v--+    |   |
+// |   |   |Const |    |      | Op0  |    |   |
+// |   |   +---v--+    |      +---+--+    |   |
+// |   |       |       |          |       |   |
+// |   |       |   +---v--+   +---v--+    |   |
+// |   |       +---| XOp1 |   | XOp2 |    |   |
+// |   |           +---+--+   +---+--+    |   |
+// |   |               |          |       |   |
+// |   | MapDefun  +---v--+   +---v--+    |   |
+// |   +-----------+ Ret0 +---+ Ret1 +----+   |
+// |               +---+--+   +---+--+        |
+// |                   |          |           |
+// |               +---v--+   +---v--+        |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+//                 +------+   +------+
+//
+//   where XOp1 and XOp2 are not convertible.
+//
+// After:
+//
+// No change because the ops are not convertible.
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunUnconvertible) {
+  FunctionDef inner =
+      CreateFunction("inner_function", {{"arg0", DT_INT32}, {"arg1", DT_INT32}},
+                     {{"ret0", DT_INT32}, {"ret1", DT_INT32}},
+                     {{"ret0", "XOp1:output:0"}, {"ret1", "XOp2:output:0"}});
+  NodeDef* x_op1 =
+      function_utils::AddNode("XOp1", "XOp1", {"const", "arg0"}, {}, &inner);
+  CHECK_NOTNULL(x_op1);
+
+  NodeDef* x_op2 = function_utils::AddNode("XOp2", "XOp2", {"op1"}, {}, &inner);
+  CHECK_NOTNULL(x_op2);
+
+  FunctionDef outer = CreateFunction(
+      "outer_function", {{"x", DT_INT32}, {"y", DT_INT32}},
+      {{"mapdefun", DT_INT32}, {"mapdefun_0", DT_INT32}},
+      {{"mapdefun", "MapDefun:output:0"}, {"mapdefun_0", "MapDefun:output:1"}});
+
+  NodeDef* map_defun = AddMapDefunNode(
+      "MapDefun", {"x", "y"}, {DT_INT32, DT_INT32}, {DT_INT32, DT_INT32},
+      {{}, {}}, inner.signature().name(), &outer);
+  CHECK_NOTNULL(map_defun);
+
+  FunctionDef outer_copy(outer);
+  FunctionDef inner_copy(inner);
+  VectorizeMapDefun(&outer, &inner, map_defun);
+  // They should be unchanged
+  EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer));
+  EXPECT_TRUE(FunctionDefsEqual(inner_copy, inner));
+}
+
+// Before:
+//
+//
+//                 +------+
+// +---------------+ Arg0 +---------+
+// |               +---+--+         |
+// |                   |            |
+// |               +---v--+         |
+// |   +-----------+ Arg0 +-----+   |
+// |   |           +---+--+     |   |
+// |   |               |        |   |
+// |   |               |        |   |
+// |   |           +---v--+     |   |
+// |   |           | Cast |     |   |
+// |   |           +---+--+     |   |
+// |   |               |        |   |
+// |   | MapDefun  +---v--+     |   |
+// |   +-----------+ Ret0 +-----+   |
+// |               +---+--+         |
+// |                   |            |
+// |               +---v--+         |
+// +---------------+ Ret0 +---------+
+//                 +------+
+//
+//
+//  After:
+//
+//                 +------+
+// +---------------+ Arg0 +---------+
+// |               +---+--+         |
+// |                   |            |
+// |               +---v--+         |
+// |               | Cast |         |
+// |               +---+--+         |
+// |                   |            |
+// |               +---v--+         |
+// +---------------+ Ret0 +---------+
+//                 +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunSimpleCast) {
+  FunctionDef inner =
+      CreateFunction("inner_function", {{"arg0", DT_INT32}},
+                     {{"ret0", DT_INT64}}, {{"ret0", "Cast:y:0"}});
+  NodeDef* cast_op =
+      AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner);
+  CHECK_NOTNULL(cast_op);
+
+  FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}},
+                                     {{"mapdefun", DT_INT64}},
+                                     {{"mapdefun", "MapDefun:output:0"}});
+
+  NodeDef* map_defun =
+      AddMapDefunNode("MapDefun", {"x"}, {DT_INT32}, {DT_INT64}, {{}},
+                      inner.signature().name(), &outer);
+  CHECK_NOTNULL(map_defun);
+
+  VectorizeMapDefun(&outer, &inner, map_defun);
+  EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
+  const NodeDef& cast_node =
+      outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+  EXPECT_EQ(cast_node.input(0), "x");
+  EXPECT_EQ(outer.ret().at("mapdefun"),
+            strings::StrCat(cast_node.name(), ":y:0"));
+  EXPECT_EQ(outer.node_def_size(), 1);
+}
+
+// Before:
+//
+//                 +------+
+// +---------------+ Arg0 +-------------------+
+// |               +---+--+                   |
+// |                   |                      |
+// |               +---v--+                   |
+// |   +-----------+ Arg0 +---------------+   |
+// |   |           +---+--+               |   |
+// |   |               |                  |   |
+// |   |               |                  |   |
+// |   |           +---v--+               |   |
+// |   |           | Cast |               |   |
+// |   |           +---+--+               |   |
+// |   |               |                  |   |
+// |   |               +----------+       |   |
+// |   |               |          |       |   |
+// |   | MapDefun  +---v--+   +---v--+    |   |
+// |   +-----------+ Ret0 +---+ Ret1 +----+   |
+// |               +---+--+   +---+--+        |
+// |                   |          |           |
+// |               +---v--+   +---v--+        |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+//                 +------+   +------+
+//
+//
+//  After:
+//
+//                 +------+
+// +---------------+ Arg0 +-------------------+
+// |               +---+--+                   |
+// |                   |                      |
+// |                   |                      |
+// |               +---v--+                   |
+// |               | Cast |                   |
+// |               +---+--+                   |
+// |                   |                      |
+// |                   +----------+           |
+// |                   |          |           |
+// |               +---v--+   +---v--+        |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+//                 +------+   +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunCastUsedTwice) {
+  // Tests that behavior is correct when an output is used more than once.
+  FunctionDef inner =
+      CreateFunction("inner_function", {{"arg0", DT_INT32}},
+                     {{"ret0", DT_INT64}, {"ret1", DT_INT64}},
+                     {{"ret0", "Cast:y:0"}, {"ret1", "Cast:y:0"}});
+  NodeDef* cast_op =
+      AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner);
+  CHECK_NOTNULL(cast_op);
+
+  FunctionDef outer = CreateFunction(
+      "outer_function", {{"x", DT_INT32}},
+      {{"mapdefun", DT_INT64}, {"mapdefun_0", DT_INT64}},
+      {{"mapdefun", "MapDefun:output:0"}, {"mapdefun_0", "MapDefun:output:1"}});
+
+  NodeDef* map_defun =
+      AddMapDefunNode("MapDefun", {"x"}, {DT_INT32}, {DT_INT64, DT_INT64},
+                      {{}, {}}, inner.signature().name(), &outer);
+  CHECK_NOTNULL(map_defun);
+
+  VectorizeMapDefun(&outer, &inner, map_defun);
+  EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
+  const NodeDef& cast_node =
+      outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+  EXPECT_EQ(cast_node.input(0), "x");
+  EXPECT_EQ(outer.ret().at("mapdefun"),
+            strings::StrCat(cast_node.name(), ":y:0"));
+  EXPECT_EQ(outer.ret().at("mapdefun_0"),
+            strings::StrCat(cast_node.name(), ":y:0"));
+  EXPECT_EQ(outer.node_def_size(), 1);
+}
+
+// Before:
+//
+//                        +------+
+// +----------------------+ Arg0 +----------------------+
+// |                      +---+--+                      |
+// |                          |                         |
+// |                      +---v--+                      |
+// |   +------------------+ Arg0 +------------------+   |
+// |   |                  +---+--+                  |   |
+// |   |                      |                     |   |
+// |   |                      |                     |   |
+// |   |                  +---v---+ num=3           |   |
+// |   |                  |Unstack| axis=0          |   |
+// |   |                  ++--+--++                 |   |
+// |   |                   |  |  |                  |   |
+// |   |              +----+  |  +-------+          |   |
+// |   |              |       |          |          |   |
+// |   | MapDefun +---v--+  +-v----+  +--v---+      |   |
+// |   +----------+ Ret0 +--+ Ret1 +--+ Ret2 +------+   |
+// |              +---+--+  +--+---+  +--+---+          |
+// |                  |        |         |              |
+// |              +---v--+  +--v---+  +--v---+          |
+// +--------------+ Ret0 +--+ Ret1 +--+ Ret2 +----------+
+//                +------+  +------+  +------+
+//
+//
+//  After:
+//
+//                        +------+
+// +----------------------+ Arg0 +----------------------+
+// |                      +---+--+                      |
+// |                          |                         |
+// |                          |                         |
+// |                          |                         |
+// |                      +---v---+ num=3               |
+// |                      |Unstack| axis=1              |
+// |                      ++--+--++                     |
+// |                       |  |  |                      |
+// |                  +----+  |  +-------+              |
+// |                  |       |          |              |
+// |                  |       |          |              |
+// |              +---v--+  +-v----+  +--v---+          |
+// +--------------+ Ret0 +--+ Ret1 +--+ Ret2 +----------+
+//                +------+  +------+  +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunOpWithMultipleOutputs) {
+  FunctionDef inner = CreateFunction(
+      "inner_function", {{"arg0", DT_INT32}},
+      {{"ret0", DT_INT32}, {"ret1", DT_INT32}, {"ret2", DT_INT32}},
+      {{"ret0", "MyUnstack:output:0"},
+       {"ret1", "MyUnstack:output:1"},
+       {"ret2", "MyUnstack:output:2"}});
+  NodeDef* unstack_op =
+      AddUnstackNode("MyUnstack", {"arg0"}, DT_INT32, 0, 3, &inner);
+  CHECK_NOTNULL(unstack_op);
+
+  FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}},
+                                     {{"mapdefun", DT_INT32},
+                                      {"mapdefun_0", DT_INT32},
+                                      {"mapdefun_1", DT_INT32}},
+                                     {{"mapdefun", "MapDefun:output:0"},
+                                      {"mapdefun_0", "MapDefun:output:1"},
+                                      {"mapdefun_1", "MapDefun:output:2"}});
+
+  NodeDef* map_defun = AddMapDefunNode(
+      "MapDefun", {"x"}, {DT_INT32}, {DT_INT32, DT_INT32, DT_INT32},
+      {{1}, {1}, {1}}, inner.signature().name(), &outer);
+  CHECK_NOTNULL(map_defun);
+
+  VectorizeMapDefun(&outer, &inner, map_defun);
+  EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
+  const NodeDef& unpack_node =
+      outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer));
+  EXPECT_EQ(unpack_node.input(0), "x");
+  EXPECT_EQ(unpack_node.attr().at("axis").i(), 1);
+  EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32);
+  EXPECT_EQ(unpack_node.attr().at("num").i(), 3);
+  EXPECT_EQ(outer.ret().at("mapdefun"),
+            strings::StrCat(unpack_node.name(), ":output:0"));
+  EXPECT_EQ(outer.ret().at("mapdefun_0"),
+            strings::StrCat(unpack_node.name(), ":output:1"));
+  EXPECT_EQ(outer.ret().at("mapdefun_1"),
+            strings::StrCat(unpack_node.name(), ":output:2"));
+  EXPECT_EQ(outer.node_def_size(), 1);
+}
+
+// Before:
+//
+//                        +------+
+// +----------------------+ Arg0 +----------------------+
+// |                      +---+--+                      |
+// |                          |                         |
+// |                      +---v--+                      |
+// |   +------------------+ Arg0 +------------------+   |
+// |   |                  +---+--+                  |   |
+// |   |                      |                     |   |
+// |   |                  +---+--+                  |   |
+// |   |                  | Cast |                  |   |
+// |   |                  +---+--+                  |   |
+// |   |                      |                     |   |
+// |   |                  +---v---+ num=3           |   |
+// |   |                  |Unstack| axis=0          |   |
+// |   |                  ++--+--++                 |   |
+// |   |                   |  |  |                  |   |
+// |   |              +----+  |  +-------+          |   |
+// |   |              |       |          |          |   |
+// |   | MapDefun +---v--+  +-v----+  +--v---+      |   |
+// |   +----------+ Ret0 +--+ Ret1 +--+ Ret2 +------+   |
+// |              +---+--+  +--+---+  +--+---+          |
+// |                  |        |         |              |
+// |              +---v--+  +--v---+  +--v---+          |
+// +--------------+ Ret0 +--+ Ret1 +--+ Ret2 +----------+
+//                +------+  +------+  +------+
+//
+//
+//  After:
+//
+//                        +------+
+// +----------------------+ Arg0 +----------------------+
+// |                      +---+--+                      |
+// |                          |                         |
+// |                      +---+--+                      |
+// |                      | Cast |                      |
+// |                      +---+--+                      |
+// |                          |                         |
+// |                      +---v---+ num=3               |
+// |                      |Unstack| axis=1              |
+// |                      ++--+--++                     |
+// |                       |  |  |                      |
+// |                  +----+  |  +-------+              |
+// |                  |       |          |              |
+// |                  |       |          |              |
+// |              +---v--+  +-v----+  +--v---+          |
+// +--------------+ Ret0 +--+ Ret1 +--+ Ret2 +----------+
+//                +------+  +------+  +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunChainedConvertibleOps) {
+  FunctionDef inner = CreateFunction(
+      "inner_function", {{"arg0", DT_INT32}},
+      {{"ret0", DT_INT32}, {"ret1", DT_INT32}, {"ret2", DT_INT32}},
+      {{"ret0", "MyUnstack:output:0"},
+       {"ret1", "MyUnstack:output:1"},
+       {"ret2", "MyUnstack:output:2"}});
+  NodeDef* cast_op =
+      AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner);
+  CHECK_NOTNULL(cast_op);
+  NodeDef* unstack_op =
+      AddUnstackNode("MyUnstack", {"Cast:y:0"}, DT_INT32, 0, 3, &inner);
+  CHECK_NOTNULL(unstack_op);
+
+  FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}},
+                                     {{"mapdefun", DT_INT32},
+                                      {"mapdefun_0", DT_INT32},
+                                      {"mapdefun_1", DT_INT32}},
+                                     {{"mapdefun", "MapDefun:output:0"},
+                                      {"mapdefun_0", "MapDefun:output:1"},
+                                      {"mapdefun_1", "MapDefun:output:2"}});
+
+  NodeDef* map_defun = AddMapDefunNode(
+      "MapDefun", {"x"}, {DT_INT32}, {DT_INT32, DT_INT32, DT_INT32},
+      {{1}, {1}, {1}}, inner.signature().name(), &outer);
+  CHECK_NOTNULL(map_defun);
+
+  VectorizeMapDefun(&outer, &inner, map_defun);
+  EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
+  const NodeDef& cast_node =
+      outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+  EXPECT_EQ(cast_node.input(0), "x");
+  const NodeDef& unpack_node =
+      outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer));
+  EXPECT_EQ(unpack_node.input(0), strings::StrCat(cast_node.name(), ":y:0"));
+  EXPECT_EQ(unpack_node.attr().at("axis").i(), 1);
+  EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32);
+  EXPECT_EQ(unpack_node.attr().at("num").i(), 3);
+
+  EXPECT_EQ(outer.ret().at("mapdefun"),
+            strings::StrCat(unpack_node.name(), ":output:0"));
+  EXPECT_EQ(outer.ret().at("mapdefun_0"),
+            strings::StrCat(unpack_node.name(), ":output:1"));
+  EXPECT_EQ(outer.ret().at("mapdefun_1"),
+            strings::StrCat(unpack_node.name(), ":output:2"));
+  EXPECT_EQ(outer.node_def_size(), 2);
+}
+
+// Before:
+//
+//
+//                 +------+
+// +---------------+ Arg0 +---------+
+// |               +---+--+         |
+// |                   |            |
+// |               +---v--+         |
+// |   +-----------+ Arg0 +-----+   |
+// |   |           +---+--+     |   |
+// |   |     +---------+        |   |
+// |   | +---v--+      |        |   |
+// |   | |Print |      |        |   |
+// |   | +---+--+      |        |   |
+// |   |     :     +---v--+     |   |
+// |   |     ::::::> Cast |     |   |
+// |   |           +---+--+     |   |
+// |   |               |        |   |
+// |   | MapDefun  +---v--+     |   |
+// |   +-----------+ Ret0 +-----+   |
+// |               +---+--+         |
+// |                   |            |
+// |               +---v--+         |
+// +---------------+ Ret0 +---------+
+//                 +------+
+//
+//
+//  After:
+//
+//  No change because we don't deal with control inputs for now.
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) {
+  FunctionDef inner =
+      CreateFunction("inner_function", {{"arg0", DT_INT32}},
+                     {{"ret0", DT_INT64}}, {{"ret0", "Cast:y:0"}});
+  // The attrs aren't relevant
+  NodeDef* print_op =
+      function_utils::AddNode("Print", "Print", {"arg0", "arg0"}, {}, &inner);
+  CHECK_NOTNULL(print_op);
+  NodeDef* cast_op = AddCastNode("Cast", {"arg0", "^Print"}, DT_INT32, DT_INT64,
+                                 false, &inner);
+  CHECK_NOTNULL(cast_op);
+
+  FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}},
+                                     {{"mapdefun", DT_INT64}},
+                                     {{"mapdefun", "MapDefun:output:0"}});
+
+  NodeDef* map_defun =
+      AddMapDefunNode("MapDefun", {"x"}, {DT_INT32}, {DT_INT64}, {{}},
+                      inner.signature().name(), &outer);
+  CHECK_NOTNULL(map_defun);
+
+  FunctionDef outer_copy(outer);
+  FunctionDef inner_copy(inner);
+  VectorizeMapDefun(&outer, &inner, map_defun);
+  // They should be unchanged
+  EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer));
+}
+
+// TODO(rachelim): More test cases when we get around to implementing them:
+// [] A badly defined converter, e.g. doesn't produce nodes that have the
+//    same number of outputs/inputs as the nodes to be converted
+// [] Converter where the 'converted' form has multiple nodes.
+// [] Case with dependent nodes, e.g. ops with const inputs that are
+//    broadcasted.
+// [] Python-side tests to actually run the functions to make sure
+//    they work.
+
+}  // namespace
+}  // namespace vectorization_utils
+}  // namespace grappler
+}  // namespace tensorflow
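
The VectorizeDefunWithControlInputs test above relies on the convention that a NodeDef input of the form "^Print" denotes a control dependency rather than a data input. As a minimal standalone sketch (not part of the patch; the helper name HasControlInput is hypothetical), the distinction can be expressed as:

    #include <string>
    #include "tensorflow/core/framework/node_def.pb.h"

    // Returns true if `node` lists `ctrl_node` as a control dependency, i.e. an
    // input of the form "^<node_name>" rather than "<node_name>:<output>:<i>".
    bool HasControlInput(const tensorflow::NodeDef& node,
                         const std::string& ctrl_node) {
      for (const std::string& input : node.input()) {
        if (!input.empty() && input[0] == '^' && input.substr(1) == ctrl_node) {
          return true;
        }
      }
      return false;
    }
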
diff --git a/tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc
index eeea269..2c36c9b 100644
--- a/tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc
+++ b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc
@@ -32,8 +32,6 @@
 namespace tensorflow {
 namespace grappler {
 
-REGISTER_GRAPH_OPTIMIZER(ExperimentalImplementationSelector);
-
 Status ExperimentalImplementationSelector::LoadFunctions(
     const GraphDef& graph) {
   lib_info_.reset(new FunctionLibraryApiInfo);
@@ -43,8 +41,20 @@
 
 Status ExperimentalImplementationSelector::MaybeOptimizeFunctionCall(
     NodeDef* node_def) const {
-  const FunctionApiInfo* info = lib_info_->GetApiInfo(node_def->op());
-  if (info == nullptr) {
+  // There are two ways of calling functions:
+  //  1. By specifying an op name as a function name, or
+  //  2. Via the @defun functional interface, where the real function name
+  //     appears as an attribute with type func.
+  std::vector<string> function_attribute_names;
+  for (const auto& attr : node_def->attr()) {
+    if (attr.second.has_func() &&
+        lib_info_->GetApiInfo(attr.second.func().name()) != nullptr) {
+      function_attribute_names.emplace_back(attr.first);
+    }
+  }
+
+  if (function_attribute_names.empty() &&
+      lib_info_->GetApiInfo(node_def->op()) == nullptr) {
     // A regular op, or a function which has no interface.
     return Status::OK();
   }
@@ -58,17 +68,25 @@
   DeviceNameUtils::ParsedName parsed_name;
   DeviceNameUtils::ParseLocalName(device, &parsed_name);
 
-  string best_function_name;
-  lib_info_->GetBestImplementation(node_def->op(), parsed_name.type,
-                                   &best_function_name);
-  if (node_def->op() != best_function_name) {
-    // The current implementation is not the best, swap the op to the best one.
-    // There will be duplicates in the graph and they will be pruned by other
-    // grappler plugin since no other node is using their output as inputs.
-    // TODO(scottzhu): Update the tf.eager.defun to register functions without
-    // having to call them with input data. That will reduce the graph size and
-    // save the work for prune them.
-    node_def->set_op(best_function_name);
+  for (const auto& attr_name : function_attribute_names) {
+    string function_name = node_def->attr().at(attr_name).func().name();
+    string best_function_name;
+    lib_info_->GetBestImplementation(function_name, parsed_name.type,
+                                     &best_function_name);
+    if (function_name != best_function_name) {
+      node_def->mutable_attr()
+          ->find(attr_name)
+          ->second.mutable_func()
+          ->set_name(best_function_name);
+    }
+  }
+  if (lib_info_->GetApiInfo(node_def->op()) != nullptr) {
+    string best_function_name;
+    lib_info_->GetBestImplementation(node_def->op(), parsed_name.type,
+                                     &best_function_name);
+    if (node_def->op() != best_function_name) {
+      node_def->set_op(best_function_name);
+    }
   }
   return Status::OK();
 }
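
The rewrite above means a function's implementation can now be swapped whether the function is invoked directly by op name or indirectly through a func-typed attribute (e.g. from @defun). A minimal sketch of the attribute plumbing, assuming only the generated NodeDef/AttrValue proto API (the helper SwapFunctionAttr and the names passed to it are hypothetical and not part of the patch):

    #include <string>
    #include "tensorflow/core/framework/attr_value.pb.h"
    #include "tensorflow/core/framework/node_def.pb.h"

    // Rewrites every func-typed attribute whose function name equals `from`
    // so that it points at the implementation named `to` instead.
    void SwapFunctionAttr(tensorflow::NodeDef* node, const std::string& from,
                          const std::string& to) {
      for (auto& attr : *node->mutable_attr()) {
        if (attr.second.has_func() && attr.second.func().name() == from) {
          attr.second.mutable_func()->set_name(to);
        }
      }
    }
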
diff --git a/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc b/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc
index 2368e57..3f1ebef 100644
--- a/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc
+++ b/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc
@@ -45,9 +45,8 @@
   GrapplerItem item;
   CHECK(fake_input.NextItem(&item));
 
-  std::unique_ptr<CustomGraphOptimizer> optimizer =
-      CustomGraphOptimizerRegistry::CreateByNameOrNull(
-          "ExperimentalImplementationSelector");
+  std::unique_ptr<CustomGraphOptimizer> optimizer(
+      new ExperimentalImplementationSelector);
   ASSERT_NE(nullptr, optimizer);
   TF_ASSERT_OK(optimizer->Init());
 
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index 8c99598..4b0cbfa 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -23,6 +23,7 @@
 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
 #include "tensorflow/core/grappler/optimizers/debug_stripper.h"
 #include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
+#include "tensorflow/core/grappler/optimizers/experimental_implementation_selector.h"
 #include "tensorflow/core/grappler/optimizers/function_optimizer.h"
 #include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
 #include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
@@ -72,6 +73,16 @@
          name == "loop_optimizer";
 }
 
+// Check if the graphdef contains nodes that indicate TPU execution.
+bool IsTPUGraphDef(const GraphDef& def) {
+  for (auto node : def.node()) {
+    if (node.op() == "TPUCompile" || node.op() == "TPUPartitionedCall") {
+      return true;
+    }
+  }
+  return false;
+}
+
 }  // namespace
 
 #define MK_OPT(NAME, VALUE) \
@@ -186,8 +197,18 @@
 Status MetaOptimizer::InitializeCustomGraphOptimizers(
     std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
   for (const auto& optimizer_config : cfg_.custom_optimizers()) {
-    auto custom_optimizer = CustomGraphOptimizerRegistry::CreateByNameOrNull(
-        optimizer_config.name());
+    // Initialize the ExperimentalImplementationSelector here instead of
+    // through the CustomGraphOptimizer registry, due to a static linking
+    // issue in TensorRT that causes double registration.
+    // TODO(laigd): Remove this hack and change it back to use the registry once
+    // the duplicate static import issue is fixed.
+    std::unique_ptr<CustomGraphOptimizer> custom_optimizer;
+    if (optimizer_config.name() == "ExperimentalImplementationSelector") {
+      custom_optimizer.reset(new ExperimentalImplementationSelector());
+    } else {
+      custom_optimizer = CustomGraphOptimizerRegistry::CreateByNameOrNull(
+          optimizer_config.name());
+    }
     if (custom_optimizer) {
       VLOG(2) << "Registered custom configurable graph optimizer: "
               << optimizer_config.name();
@@ -331,13 +352,26 @@
 
 Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                GraphDef* optimized_graph) {
-  LOG(INFO) << "Starting optimization for grappler item: " << item.id;
+  VLOG(1) << "Starting optimization for grappler item: " << item.id;
   optimization_results_.clear();
 
   // 1. Optimize main graph
   TF_RETURN_IF_ERROR(OptimizeGraph(cluster, item, optimized_graph));
   VLOG(1) << "Optimized main graph.";
 
+  // Skip optimizing functions if this is a TPU graph. Currently, Grappler
+  // passes do not handle TPU functions correctly in a variety of ways (note
+  // that due to the pre-placement TPU graph rewriting passes, the TPU-related
+  // ops are encapsulated away into functions). For example, TPU graphs contain
+  // a TPUReplicateMetadata node that carries relevant TPU metadata, and
+  // Grappler passes could prune it away. Grappler passes could also cause
+  // issues around shape inference. Since the desired and existing behavior is
+  // to not optimize TPU functions with Grappler, this check preserves that.
+  if (IsTPUGraphDef(*optimized_graph)) {
+    VLOG(2) << "Skipping optimizing funcs for TPU graphs";
+    return Status::OK();
+  }
+
   // 2. Optimize function library
   FunctionLibraryDefinition flib(OpRegistry::Global(),
                                  optimized_graph->library());
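
For context on the special-casing in InitializeCustomGraphOptimizers above: custom optimizers are requested by name through the RewriterConfig proto, and ExperimentalImplementationSelector is now constructed directly when that name is seen. A minimal sketch of how such a request is expressed (MakeConfigWithSelector is a hypothetical helper; the session wiring around it is omitted):

    #include "tensorflow/core/protobuf/rewriter_config.pb.h"

    // Builds a RewriterConfig that asks the meta optimizer to run the
    // optimizer known by the name "ExperimentalImplementationSelector".
    tensorflow::RewriterConfig MakeConfigWithSelector() {
      tensorflow::RewriterConfig config;
      config.add_custom_optimizers()->set_name(
          "ExperimentalImplementationSelector");
      return config;
    }
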
diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc
index 03e36a7..008a289 100644
--- a/tensorflow/core/grappler/optimizers/remapper.cc
+++ b/tensorflow/core/grappler/optimizers/remapper.cc
@@ -218,7 +218,7 @@
 void Remapper::Feedback(Cluster* /*cluster*/, const GrapplerItem& /*item*/,
                         const GraphDef& /*optimized_graph*/,
                         double /*result*/) {
-  // Nothing to do for ArithmeticOptimizer.
+  // Nothing to do for RemapperOptimizer.
 }
 
 }  // namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/shape_optimizer.cc b/tensorflow/core/grappler/optimizers/shape_optimizer.cc
index caa0b7b..4542d17 100644
--- a/tensorflow/core/grappler/optimizers/shape_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/shape_optimizer.cc
@@ -20,10 +20,9 @@
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/grappler/graph_view.h"
 #include "tensorflow/core/grappler/grappler_item.h"
-#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
-
 #include "tensorflow/core/grappler/op_types.h"
 #include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
 #include "tensorflow/core/lib/core/errors.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD
index e540cc0..bdbb883 100644
--- a/tensorflow/core/grappler/utils/BUILD
+++ b/tensorflow/core/grappler/utils/BUILD
@@ -1,6 +1,10 @@
 licenses(["notice"])  # Apache 2.0
 
 load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load(
+    "//tensorflow/core:platform/default/build_config.bzl",
+    "tf_protos_grappler",
+)
 
 cc_library(
     name = "scc",
@@ -210,3 +214,28 @@
         "//tensorflow/core:testlib",
     ],
 )
+
+cc_library(
+    name = "symbolic_shapes",
+    srcs = ["symbolic_shapes.cc"],
+    hdrs = ["symbolic_shapes.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+    ] + tf_protos_grappler(),
+)
+
+tf_cc_test(
+    name = "symbolic_shapes_test",
+    srcs = ["symbolic_shapes_test.cc"],
+    deps = [
+        ":symbolic_shapes",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+    ],
+)
diff --git a/tensorflow/core/grappler/optimizers/symbolic_shapes.cc b/tensorflow/core/grappler/utils/symbolic_shapes.cc
similarity index 98%
rename from tensorflow/core/grappler/optimizers/symbolic_shapes.cc
rename to tensorflow/core/grappler/utils/symbolic_shapes.cc
index 155843a..1666de4 100644
--- a/tensorflow/core/grappler/optimizers/symbolic_shapes.cc
+++ b/tensorflow/core/grappler/utils/symbolic_shapes.cc
@@ -13,7 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
 #include "tensorflow/core/util/bcast.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/grappler/optimizers/symbolic_shapes.h b/tensorflow/core/grappler/utils/symbolic_shapes.h
similarity index 94%
rename from tensorflow/core/grappler/optimizers/symbolic_shapes.h
rename to tensorflow/core/grappler/utils/symbolic_shapes.h
index ace7bd1..0a7d8ac 100644
--- a/tensorflow/core/grappler/optimizers/symbolic_shapes.h
+++ b/tensorflow/core/grappler/utils/symbolic_shapes.h
@@ -13,8 +13,8 @@
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SYMBOLIC_SHAPES_H_
-#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SYMBOLIC_SHAPES_H_
+#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_SYMBOLIC_SHAPES_H_
+#define TENSORFLOW_CORE_GRAPPLER_UTILS_SYMBOLIC_SHAPES_H_
 
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
@@ -74,4 +74,4 @@
 }  // namespace grappler
 }  // end namespace tensorflow
 
-#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SYMBOLIC_SHAPES_H_
+#endif  // TENSORFLOW_CORE_GRAPPLER_UTILS_SYMBOLIC_SHAPES_H_
diff --git a/tensorflow/core/grappler/optimizers/symbolic_shapes_test.cc b/tensorflow/core/grappler/utils/symbolic_shapes_test.cc
similarity index 98%
rename from tensorflow/core/grappler/optimizers/symbolic_shapes_test.cc
rename to tensorflow/core/grappler/utils/symbolic_shapes_test.cc
index 7ce995d..6ac644c 100644
--- a/tensorflow/core/grappler/optimizers/symbolic_shapes_test.cc
+++ b/tensorflow/core/grappler/utils/symbolic_shapes_test.cc
@@ -13,7 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/platform/test.h"
 
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 94d3ab4..08245e6 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -30,6 +30,7 @@
     "//tensorflow:tensorflow.bzl",
     "if_android",
     "tf_cc_test",
+    "tf_cc_test_mkl",
     "tf_cc_tests",
     "tf_cc_binary",
     "tf_copts",
@@ -50,6 +51,10 @@
     "tf_kernel_tests_linkstatic",
 )
 load(
+    "//tensorflow/core:platform/default/build_config_root.bzl",
+    "tf_cuda_tests_tags",
+)
+load(
     "//third_party/mkl:build_defs.bzl",
     "if_mkl",
     "if_mkl_ml",
@@ -636,6 +641,7 @@
         ":reshape_op",
         ":reverse_op",
         ":reverse_sequence_op",
+        ":searchsorted_op",
         ":shape_ops",
         ":slice_op",
         ":snapshot_op",
@@ -869,6 +875,12 @@
 )
 
 tf_kernel_library(
+    name = "searchsorted_op",
+    prefix = "searchsorted_op",
+    deps = ARRAY_DEPS,
+)
+
+tf_kernel_library(
     name = "inplace_ops",
     prefix = "inplace_ops",
     deps = ARRAY_DEPS,
@@ -1105,7 +1117,7 @@
     name = "depthwise_conv_ops_test",
     size = "small",
     srcs = ["depthwise_conv_ops_test.cc"],
-    tags = ["requires-gpu-sm35"],
+    tags = tf_cuda_tests_tags(),
     deps = [
         ":conv_ops",
         ":image",
@@ -2702,6 +2714,7 @@
 )
 
 LOGGING_DEPS = [
+    "@com_google_absl//absl/strings",
     "//tensorflow/core:framework",
     "//tensorflow/core:lib",
     "//tensorflow/core:lib_internal",
@@ -2759,6 +2772,7 @@
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core:testlib",
+        "@com_google_absl//absl/strings",
     ],
 )
 
@@ -4396,6 +4410,7 @@
         ":reduce_join_op",
         ":regex_full_match_op",
         ":regex_replace_op",
+        ":string_format_op",
         ":string_join_op",
         ":string_length_op",
         ":string_split_op",
@@ -4427,6 +4442,30 @@
 )
 
 tf_kernel_library(
+    name = "string_format_op",
+    prefix = "string_format_op",
+    deps = STRING_DEPS + ["@com_google_absl//absl/strings"],
+)
+
+tf_cc_test(
+    name = "string_format_op_test",
+    size = "small",
+    srcs = ["string_format_op_test.cc"],
+    deps = [
+        ":string_format_op",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+        "//tensorflow/core/kernels:ops_testutil",
+        "//tensorflow/core/kernels:ops_util",
+    ],
+)
+
+tf_kernel_library(
     name = "string_join_op",
     prefix = "string_join_op",
     deps = STRING_DEPS,
@@ -6228,6 +6267,26 @@
     ] + mkl_deps(),
 )
 
+tf_cc_test_mkl(
+    name = "mkl_conv_ops_test",
+    size = "small",
+    srcs = ["mkl_conv_ops_test.cc"],
+    deps = [
+        ":ops_testutil",
+        ":ops_util",
+        "//tensorflow/cc:cc_ops",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:tensorflow",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
+
 tf_mkl_kernel_library(
     name = "mkl_tfconv_op",
     prefix = "mkl_tfconv",
diff --git a/tensorflow/core/kernels/bias_op.cc b/tensorflow/core/kernels/bias_op.cc
index 7b28c8e..e15ea82 100644
--- a/tensorflow/core/kernels/bias_op.cc
+++ b/tensorflow/core/kernels/bias_op.cc
@@ -134,8 +134,8 @@
     if (data_format_ == FORMAT_NCHW) {
       int32 batch, height, width, channel;
       GetBiasValueDims(input, data_format_, &batch, &height, &width, &channel);
-      Eigen::DSizes<int32, 4> four_dims(1, channel, 1, 1);
-      Eigen::DSizes<int32, 4> broad_cast_dims(batch, 1, height, width);
+      Eigen::DSizes<Eigen::Index, 4> four_dims(1, channel, 1, 1);
+      Eigen::DSizes<Eigen::Index, 4> broad_cast_dims(batch, 1, height, width);
       const Device& d = context->eigen_device<Device>();
       output->tensor<T, 4>().device(d) =
           input.tensor<T, 4>() +
@@ -247,14 +247,14 @@
         OP_REQUIRES(context, output_backprop.dims() == 4,
                     errors::InvalidArgument(
                         "NCHW format supports only 4D input/output tensor."));
-        Eigen::DSizes<int, 4> four_dims(batch, channel, height, width);
+        Eigen::DSizes<Eigen::Index, 4> four_dims(batch, channel, height, width);
 #ifdef EIGEN_HAS_INDEX_LIST
         using idx0 = Eigen::type2index<0>;
         using idx2 = Eigen::type2index<2>;
         using idx3 = Eigen::type2index<3>;
         Eigen::IndexList<idx0, idx2, idx3> reduction_axes;
 #else
-        Eigen::array<int, 3> reduction_axes = {0, 2, 3};
+        Eigen::array<Eigen::Index, 3> reduction_axes = {0, 2, 3};
 #endif
         output->template flat<T>().device(context->eigen_device<Device>()) =
             output_backprop.flat<T>()
@@ -263,11 +263,12 @@
                 .sum(reduction_axes)
                 .template cast<T>();  // End of code by intel_tf.
       } else {
-        Eigen::DSizes<int, 2> two_dims(batch * height * width, channel);
+        Eigen::DSizes<Eigen::Index, 2> two_dims(batch * height * width,
+                                                channel);
 #ifdef EIGEN_HAS_INDEX_LIST
         Eigen::IndexList<Eigen::type2index<0> > reduction_axis;
 #else
-        Eigen::array<int, 1> reduction_axis = {0};
+        Eigen::array<Eigen::Index, 1> reduction_axis = {0};
 #endif
         output->template flat<T>().device(context->eigen_device<Device>()) =
             output_backprop.flat<T>()
diff --git a/tensorflow/core/kernels/bincount_op_gpu.cu.cc b/tensorflow/core/kernels/bincount_op_gpu.cu.cc
index 6074b3e..7d09e9b 100644
--- a/tensorflow/core/kernels/bincount_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/bincount_op_gpu.cu.cc
@@ -17,7 +17,7 @@
 
 #define EIGEN_USE_GPU
 
-#include "external/cub_archive/cub/device/device_histogram.cuh"
+#include "third_party/cub/device/device_histogram.cuh"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
index b2efa06..4ae26fb 100644
--- a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
+++ b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
@@ -334,30 +334,34 @@
         // Proto to store debug outputs, per example.
         boosted_trees::DebugOutput example_debug_info;
         // Initial bias prediction. E.g., prediction based off training mean.
-        example_debug_info.add_logits_path(resource->GetTreeWeight(0) *
-                                           resource->node_value(0, 0));
+        float tree_logit =
+            resource->GetTreeWeight(0) * resource->node_value(0, 0);
+        example_debug_info.add_logits_path(tree_logit);
         int32 node_id = 0;
         int32 tree_id = 0;
         int32 feature_id;
-        float tree_logit;
         float past_trees_logit = 0;  // Sum of leaf logits from prior trees.
-        // Populate proto.
+        // Go through each tree and populate proto.
         while (tree_id <= last_tree) {
-          // Feature id used to split.
-          feature_id = resource->feature_id(tree_id, node_id);
-          example_debug_info.add_feature_ids(feature_id);
-          // Get logit after split.
-          node_id = resource->next_node(tree_id, node_id, i,
-                                        batch_bucketized_features);
-          tree_logit = resource->GetTreeWeight(tree_id) *
-                       resource->node_value(tree_id, node_id);
-          // Output logit incorporates sum of leaf logits from prior trees.
-          example_debug_info.add_logits_path(tree_logit + past_trees_logit);
-          if (resource->is_leaf(tree_id, node_id)) {
-            // Move onto other trees.
-            past_trees_logit += tree_logit;
+          if (resource->is_leaf(tree_id, node_id)) {  // Move on to other trees.
+            // Accumulate tree_logit only for non-root leaves, except for the
+            // bias tree (tree 0), whose logit is always accumulated.
+            if (tree_id == 0 || node_id > 0) {
+              past_trees_logit += tree_logit;
+            }
             ++tree_id;
             node_id = 0;
+          } else {  // Add to proto.
+            // Feature id used to split.
+            feature_id = resource->feature_id(tree_id, node_id);
+            example_debug_info.add_feature_ids(feature_id);
+            // Get logit after split.
+            node_id = resource->next_node(tree_id, node_id, i,
+                                          batch_bucketized_features);
+            tree_logit = resource->GetTreeWeight(tree_id) *
+                         resource->node_value(tree_id, node_id);
+            // Output logit incorporates sum of leaf logits from prior trees.
+            example_debug_info.add_logits_path(tree_logit + past_trees_logit);
           }
         }
         // Set output as serialized proto containing debug info.
diff --git a/tensorflow/core/kernels/conv_2d.h b/tensorflow/core/kernels/conv_2d.h
index de9b698..639c306 100644
--- a/tensorflow/core/kernels/conv_2d.h
+++ b/tensorflow/core/kernels/conv_2d.h
@@ -137,17 +137,16 @@
   }
 };
 
-// Shuffles a filter tensor from:
-//   [<spatial_dims>, in, out]
-// to:
-//   [out, in, <spatial_dims>]
+// Shuffles a filter tensor from TensorFlow format HWIO to dst_filter_format.
+//
+// Note: Currently OIHW is the only supported destination format. Support for
+// OHWI format will be added in a follow-up change.
 template <typename Device, typename T, typename IndexType, int NDIMS>
 struct TransformFilter {
-  void operator()(const Device& d,
+  void operator()(const Device& d, FilterTensorFormat dst_filter_format,
                   typename TTypes<T, NDIMS, IndexType>::ConstTensor in,
                   typename TTypes<T, NDIMS, IndexType>::Tensor out) {
-    // We want a 3, 2, 0, 1 shuffle. Merge the spatial dimensions together
-    // to speed up the shuffle operation.
+    // Merge the spatial dimensions together to speed up the shuffle operation.
     Eigen::DSizes<IndexType, 3> merged_dims;
     merged_dims[0] = in.dimension(0);  // spatial dimensions
     for (int i = 1; i < NDIMS - 2; ++i) {
@@ -156,16 +155,30 @@
     merged_dims[1] = in.dimension(NDIMS - 2);  // input filters
     merged_dims[2] = in.dimension(NDIMS - 1);  // output filters
 
+    CHECK(dst_filter_format == FORMAT_OIHW)
+        << "Unsupported destination filter format: "
+        << ToString(dst_filter_format);
+    // The source filter format is FORMAT_HWIO; the spatial dimensions HW have
+    // been merged into the leading dimension above.
+    Eigen::DSizes<IndexType, 3> shuffling_perm =
+        Eigen::DSizes<IndexType, 3>(2, 1, 0);
+
     Eigen::DSizes<IndexType, NDIMS> expanded_dims;
-    expanded_dims[0] = in.dimension(NDIMS - 1);  // output filters
-    expanded_dims[1] = in.dimension(NDIMS - 2);  // input filters
-    for (int i = 0; i < NDIMS - 2; ++i) {        // spatial dimensions
-      expanded_dims[i + 2] = in.dimension(i);
+    int out_index = 0;
+    for (int merged_dim = 0; merged_dim < merged_dims.rank(); ++merged_dim) {
+      if (shuffling_perm[merged_dim] == 0) {
+        for (int spatial_dim = 0; spatial_dim < NDIMS - 2; ++spatial_dim) {
+          expanded_dims[out_index++] = in.dimension(spatial_dim);
+        }
+      } else {
+        constexpr int kLastSpatialDim = NDIMS - 3;
+        expanded_dims[out_index++] =
+            in.dimension(kLastSpatialDim + shuffling_perm[merged_dim]);
+      }
     }
 
-    out.device(d) = in.reshape(merged_dims)
-                        .shuffle(Eigen::DSizes<IndexType, 3>(2, 1, 0))
-                        .reshape(expanded_dims);
+    out.device(d) =
+        in.reshape(merged_dims).shuffle(shuffling_perm).reshape(expanded_dims);
   }
 };
 
@@ -282,7 +295,9 @@
                   const gtl::ArraySlice<int64>& input_dims, T* out);
 };
 
-// Reverses the effect of TransformFilter above.
+// Transforms the filter back from OIHW to HWIO format, reversing the effect
+// of TransformFilter above.
+// TODO(hinsu): Support reverse transformation from filter format OHWI as well.
 template <typename Device, typename T, int NDIMS>
 struct ReverseTransformFilter {
   void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
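
To make the layout change performed by TransformFilter above concrete: for a 4-D filter, the HWIO element at (h, w, i, o) lands at OIHW coordinate (o, i, h, w). Below is a standalone sketch of the corresponding index arithmetic (HwioToOihwOffset is a hypothetical name, not part of the patch):

    #include <array>
    #include <cstdint>

    // Flat offset, in an OIHW-laid-out buffer, of the element that sits at
    // coordinate (h, w, i, o) in an HWIO filter of shape [H, W, I, O].
    inline int64_t HwioToOihwOffset(int64_t h, int64_t w, int64_t i, int64_t o,
                                    const std::array<int64_t, 4>& hwio_shape) {
      const int64_t H = hwio_shape[0];
      const int64_t W = hwio_shape[1];
      const int64_t I = hwio_shape[2];
      // OIHW strides: o -> I*H*W, i -> H*W, h -> W, w -> 1.
      return ((o * I + i) * H + h) * W + w;
    }
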
diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc
index 63b1bcd..9e86a16 100644
--- a/tensorflow/core/kernels/conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc
@@ -1018,7 +1018,8 @@
   extern template struct InflatePadAndShuffle<GPUDevice, T, 4, int>;     \
   template <>                                                            \
   void TransformFilter<GPUDevice, T, int, 4>::operator()(                \
-      const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in,    \
+      const GPUDevice& d, FilterTensorFormat dst_filter_format,          \
+      typename TTypes<T, 4, int>::ConstTensor in,                        \
       typename TTypes<T, 4, int>::Tensor out);                           \
   extern template struct TransformFilter<GPUDevice, T, int, 4>;          \
   template <>                                                            \
diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc
index d664a11..43bb5ea 100644
--- a/tensorflow/core/kernels/conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_input_ops.cc
@@ -901,7 +901,8 @@
                               &transformed_filter));
 
   functor::TransformFilter<GPUDevice, T, int, 4>()(
-      ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 4>()),
+      ctx->eigen_device<GPUDevice>(), FORMAT_OIHW,
+      To32Bit(filter.tensor<T, 4>()),
       To32Bit(transformed_filter.tensor<T, 4>()));
 
   Tensor transformed_out_backprop;
@@ -1090,7 +1091,8 @@
   extern template struct InflatePadAndShuffle<GPUDevice, T, 4, int>;     \
   template <>                                                            \
   void TransformFilter<GPUDevice, T, int, 4>::operator()(                \
-      const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in,    \
+      const GPUDevice& d, FilterTensorFormat dst_filter_format,          \
+      typename TTypes<T, 4, int>::ConstTensor in,                        \
       typename TTypes<T, 4, int>::Tensor out);                           \
   extern template struct TransformFilter<GPUDevice, T, int, 4>;          \
   template <>                                                            \
diff --git a/tensorflow/core/kernels/conv_grad_ops_3d.cc b/tensorflow/core/kernels/conv_grad_ops_3d.cc
index d26b86c..bab91f5 100644
--- a/tensorflow/core/kernels/conv_grad_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_grad_ops_3d.cc
@@ -1054,7 +1054,8 @@
 #define DECLARE_GPU_SPEC(T)                                           \
   template <>                                                         \
   void TransformFilter<GPUDevice, T, int, 5>::operator()(             \
-      const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
+      const GPUDevice& d, FilterTensorFormat dst_filter_format,       \
+      typename TTypes<T, 5, int>::ConstTensor in,                     \
       typename TTypes<T, 5, int>::Tensor out);                        \
   template <>                                                         \
   void ReverseTransformFilter<GPUDevice, T, 5>::operator()(           \
@@ -1287,7 +1288,8 @@
                          dims.filter_size(1), dims.filter_size(2)}),
             &transformed_filter));
     functor::TransformFilter<GPUDevice, T, int, 5>()(
-        context->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 5>()),
+        context->eigen_device<GPUDevice>(), FORMAT_OIHW,
+        To32Bit(filter.tensor<T, 5>()),
         To32Bit(transformed_filter.tensor<T, 5>()));
 
     // Shape: batch, filters, z, y, x.
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index ef69241..717a9f4 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -680,9 +680,9 @@
                           TensorShape({filter.dim_size(3), filter.dim_size(2),
                                        filter.dim_size(0), filter.dim_size(1)}),
                           &transformed_filter));
-
   functor::TransformFilter<GPUDevice, T, int, 4>()(
-      ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 4>()),
+      ctx->eigen_device<GPUDevice>(), FORMAT_OIHW,
+      To32Bit(filter.tensor<T, 4>()),
       To32Bit(transformed_filter.tensor<T, 4>()));
 
   Tensor transformed_output;
@@ -731,9 +731,15 @@
   if (cudnn_use_autotune &&
       !AutoTuneConv::GetInstance()->Find(conv_parameters, &algorithm_config)) {
     std::vector<AlgorithmDesc> algorithms;
-    CHECK(stream->parent()->GetConvolveAlgorithms(
-        conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(stream->parent()),
-        &algorithms));
+    OP_REQUIRES(
+        ctx,
+        stream->parent()->GetConvolveAlgorithms(
+            conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(
+                stream->parent()),
+            &algorithms),
+        errors::Unknown("Failed to get convolution algorithm. This is probably "
+                        "because cuDNN failed to initialize, so try looking to "
+                        "see if a warning log message was printed above."));
     ProfileResult best_result;
     ProfileResult best_result_no_scratch;
     for (auto profile_algorithm : algorithms) {
@@ -823,7 +829,8 @@
   extern template struct MatMulConvFunctor<GPUDevice, T>;                    \
   template <>                                                                \
   void TransformFilter<GPUDevice, T, int, 4>::operator()(                    \
-      const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in,        \
+      const GPUDevice& d, FilterTensorFormat dst_filter_format,              \
+      typename TTypes<T, 4, int>::ConstTensor in,                            \
       typename TTypes<T, 4, int>::Tensor out);                               \
   extern template struct TransformFilter<GPUDevice, T, int, 4>;              \
   template <>                                                                \
diff --git a/tensorflow/core/kernels/conv_ops_3d.cc b/tensorflow/core/kernels/conv_ops_3d.cc
index a1eed4e..83df4dc 100644
--- a/tensorflow/core/kernels/conv_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_ops_3d.cc
@@ -386,7 +386,8 @@
     // filter: [x, y, z, in, out]
     // t_filter: [out, in, x, y, z]
     functor::TransformFilter<GPUDevice, T, int, 5>()(
-        ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 5>()),
+        ctx->eigen_device<GPUDevice>(), FORMAT_OIHW,
+        To32Bit(filter.tensor<T, 5>()),
         To32Bit(transformed_filter.tensor<T, 5>()));
 
     Tensor transformed_output;
@@ -434,10 +435,16 @@
     if (cudnn_use_autotune && !AutoTuneConv3d::GetInstance()->Find(
                                   conv_parameters, &algorithm_config)) {
       std::vector<AlgorithmDesc> algorithms;
-      CHECK(stream->parent()->GetConvolveAlgorithms(
-          conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(
-              stream->parent()),
-          &algorithms));
+      OP_REQUIRES(ctx,
+                  stream->parent()->GetConvolveAlgorithms(
+                      conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(
+                          stream->parent()),
+                      &algorithms),
+                  errors::Unknown(
+                      "Failed to get convolution algorithm. This is probably "
+                      "because cuDNN failed to initialize, so try looking to "
+                      "see if a warning log message was printed above."));
+
       ProfileResult best_result;
       ProfileResult best_result_no_scratch;
       for (auto profile_algorithm : algorithms) {
@@ -514,7 +521,8 @@
 #define DECLARE_GPU_SPEC(T)                                           \
   template <>                                                         \
   void TransformFilter<GPUDevice, T, int, 5>::operator()(             \
-      const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
+      const GPUDevice& d, FilterTensorFormat dst_filter_format,       \
+      typename TTypes<T, 5, int>::ConstTensor in,                     \
       typename TTypes<T, 5, int>::Tensor out);                        \
   template <>                                                         \
   void ReverseTransformFilter<GPUDevice, T, 5>::operator()(           \
diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h
index afc611f..21d135d 100644
--- a/tensorflow/core/kernels/conv_ops_gpu.h
+++ b/tensorflow/core/kernels/conv_ops_gpu.h
@@ -142,8 +142,12 @@
   template <typename T>
   bool ShouldIncludeWinogradNonfusedAlgo(
       se::StreamExecutor* stream_exec) const {
+    auto* dnn_support = stream_exec->AsDnn();
+    if (!dnn_support) {
+      return false;
+    }
     // Skip this check for cuDNN 7 and newer.
-    auto version = stream_exec->AsDnn()->GetVersion();
+    auto version = dnn_support->GetVersion();
     if (version.ok() && version.ValueOrDie().major_version() >= 7) {
       return true;
     }
diff --git a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
index a5fa48f..46167db 100644
--- a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
+++ b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
@@ -170,51 +170,33 @@
   return tensor_index;
 }
 
-// A Cuda custom kernel that swaps dimension-0 and dimension-2 of a 3D tensor.
-template <typename T, bool conjugate = false>
-__global__ void SwapDimension0And2InTensor3Simple(int nthreads, const T* input,
-                                                  Dimension<3> input_dims,
-                                                  T* output) {
+// A simple CUDA custom kernel that shuffles the dimensions of a 3D tensor
+// according to the permutation given as template parameters. Shuffle permutation
+// <sp0, sp1, sp2> shuffles dimensions such that input dimension 0 goes to sp0,
+// 1 goes to sp1 and 2 goes to sp2. For example, shuffle permutation <2, 0, 1>
+// will populate output so that input[x][y][z] is equal to (*output)[y][z][x].
+//
+// Requires that nthreads is equal to the total number of elements in the input
+// tensor.
+template <typename T, int sp0, int sp1, int sp2, bool conjugate = false>
+__global__ void ShuffleInTensor3Simple(int nthreads, const T* input,
+                                       Dimension<3> input_dims, T* output) {
   Dimension<3> output_dims;
-  output_dims[0] = input_dims[2];
-  output_dims[1] = input_dims[1];
-  output_dims[2] = input_dims[0];
+  output_dims[sp0] = input_dims[0];
+  output_dims[sp1] = input_dims[1];
+  output_dims[sp2] = input_dims[2];
 
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    int output_index = index;
-
+  // Iterate over the output rather than the input for better performance.
+  // Iterating over the output generates sequential writes and random reads,
+  // which perform better than sequential reads and random writes.
+  CUDA_1D_KERNEL_LOOP(output_index, nthreads) {
     Index<3> output_tensor_index = FlatToTensorIndex(output_index, output_dims);
 
     Index<3> input_tensor_index;
-    input_tensor_index[0] = output_tensor_index[2];
-    input_tensor_index[1] = output_tensor_index[1];
-    input_tensor_index[2] = output_tensor_index[0];
-
-    int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
-
-    output[output_index] =
-        maybe_conj<T, conjugate>::run(ldg(input + input_index));
-  }
-}
-
-// A Cuda custom kernel that swaps dimension-1 and dimension-2 of a 3D tensor.
-template <typename T, bool conjugate = false>
-__global__ void SwapDimension1And2InTensor3Simple(int nthreads, const T* input,
-                                                  Dimension<3> input_dims,
-                                                  T* output) {
-  Dimension<3> output_dims;
-  output_dims[0] = input_dims[0];
-  output_dims[1] = input_dims[2];
-  output_dims[2] = input_dims[1];
-
-  CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    int output_index = index;
-    Index<3> output_tensor_index = FlatToTensorIndex(output_index, output_dims);
-
-    Index<3> input_tensor_index;
-    input_tensor_index[0] = output_tensor_index[0];
-    input_tensor_index[1] = output_tensor_index[2];
-    input_tensor_index[2] = output_tensor_index[1];
+    input_tensor_index[0] = output_tensor_index[sp0];
+    input_tensor_index[1] = output_tensor_index[sp1];
+    input_tensor_index[2] = output_tensor_index[sp2];
 
     int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
 
@@ -439,7 +421,7 @@
 template <typename T, int NDIMS>
 struct TransformFilter<GPUDevice, T, int, NDIMS> {
   typedef GPUDevice Device;
-  void operator()(const Device& d,
+  void operator()(const Device& d, FilterTensorFormat dst_filter_format,
                   typename TTypes<T, NDIMS, int>::ConstTensor in,
                   typename TTypes<T, NDIMS, int>::Tensor out) {
     Dimension<3> combined_dims;
@@ -450,13 +432,18 @@
     combined_dims[1] = in.dimension(NDIMS - 2);  // input filters
     combined_dims[2] = in.dimension(NDIMS - 1);  // output filters
     CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
-    SwapDimension0And2InTensor3Simple<T>
+
+    CHECK(dst_filter_format == FORMAT_OIHW)
+        << "Unsupported output layout: " << ToString(dst_filter_format);
+
+    ShuffleInTensor3Simple<T, 2, 1, 0>
         <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
             config.virtual_thread_count, in.data(), combined_dims, out.data());
   }
 };
 
-// Converts Cudnn filter format back to TensorFlow filter format.
+// Converts Cudnn filter format OIHW back to TensorFlow filter format HWIO.
+// TODO(hinsu): Support reverse transformation from filter format OHWI as well.
 template <typename T, int NDIMS>
 struct ReverseTransformFilter<GPUDevice, T, NDIMS> {
   typedef GPUDevice Device;
@@ -470,7 +457,7 @@
       combined_dims[2] *= in.dimension(i);
     }
     CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
-    SwapDimension0And2InTensor3Simple<T>
+    ShuffleInTensor3Simple<T, 2, 1, 0>
         <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
             config.virtual_thread_count, in.data(), combined_dims, out.data());
   }
@@ -937,7 +924,7 @@
   } else {
     int total_element_count = input_dims[0] * input_dims[1] * input_dims[2];
     CudaLaunchConfig config = GetCudaLaunchConfig(total_element_count, d);
-    SwapDimension1And2InTensor3Simple<T, conjugate>
+    ShuffleInTensor3Simple<T, 0, 2, 1, conjugate>
         <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
             config.virtual_thread_count, input, input_dims, output);
   }
@@ -969,7 +956,7 @@
                                static_cast<int>(combined_dims[2])};
     size_t total_size = combined_dims[0] * combined_dims[1] * combined_dims[2];
     CudaLaunchConfig config = GetCudaLaunchConfig(total_size, d);
-    SwapDimension0And2InTensor3Simple<T, conjugate>
+    ShuffleInTensor3Simple<T, 2, 1, 0, conjugate>
         <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
             config.virtual_thread_count, in, input_dims, out);
   }
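
The ShuffleInTensor3Simple kernel introduced above is parameterized by a permutation <sp0, sp1, sp2> that sends input dimension d to output dimension sp_d. A host-side sketch of the same index mapping, useful for reasoning about which specialization to instantiate (ShuffleIndex is a hypothetical helper, not part of the patch):

    #include <array>

    // For permutation <sp0, sp1, sp2>, input dimension d goes to output
    // dimension sp_d, so for <2, 0, 1> an input index (x, y, z) maps to the
    // output index (y, z, x), matching the kernel comment above.
    template <int sp0, int sp1, int sp2>
    std::array<int, 3> ShuffleIndex(const std::array<int, 3>& input_index) {
      std::array<int, 3> output_index;
      output_index[sp0] = input_index[0];
      output_index[sp1] = input_index[1];
      output_index[sp2] = input_index[2];
      return output_index;
    }
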
diff --git a/tensorflow/core/kernels/data/batch_dataset_op.cc b/tensorflow/core/kernels/data/batch_dataset_op.cc
index 887b8c8..d1db1d7 100644
--- a/tensorflow/core/kernels/data/batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/batch_dataset_op.cc
@@ -117,7 +117,7 @@
           : DatasetIterator<Dataset>(params) {}
 
       Status Initialize(IteratorContext* ctx) override {
-        SetMetadata(ctx, "batch_size", dataset()->batch_size_);
+        AddConstantParameter(ctx, "batch_size", dataset()->batch_size_);
         return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
       }
 
diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc
index 31c8f5c..0bb929b 100644
--- a/tensorflow/core/kernels/data/captured_function.cc
+++ b/tensorflow/core/kernels/data/captured_function.cc
@@ -22,39 +22,96 @@
 #include "tensorflow/core/lib/gtl/optional.h"
 #include "tensorflow/core/lib/random/random.h"
 #include "tensorflow/core/platform/notification.h"
+#include "tensorflow/core/util/ptr_util.h"
 
 namespace tensorflow {
 namespace data {
 
-/* static */
-Status CapturedFunction::Create(
-    const NameAttrList& func, std::vector<Tensor> captured_inputs,
-    std::unique_ptr<CapturedFunction>* out_function) {
-  return Create(func, std::move(captured_inputs), true, out_function);
-}
+namespace {
 
-/* static */
-Status CapturedFunction::Create(
-    const NameAttrList& func, std::vector<Tensor> captured_inputs,
-    bool use_inter_op_parallelism,
-    std::unique_ptr<CapturedFunction>* out_function) {
-  out_function->reset(new CapturedFunction(func, std::move(captured_inputs),
-                                           use_inter_op_parallelism));
-  return Status::OK();
-}
+// Simplistic implementation of the `StepStatsCollectorInterface` that only
+// cares about collecting the CPU time needed to execute a captured function.
+class SimpleStepStatsCollector : public StepStatsCollectorInterface {
+ public:
+  void IncrementProcessingTime(int64 delta) {
+    mutex_lock l(mu_);
+    processing_time_ += delta;
+  }
+
+  NodeExecStatsInterface* CreateNodeExecStats(const Node* node) override {
+    return new SimpleNodeExecStats(this);
+  }
+
+  string ReportAllocsOnResourceExhausted(const string& err) override {
+    return "";
+  }
+
+  int64 processing_time() {
+    tf_shared_lock l(mu_);
+    return processing_time_;
+  }
+
+ private:
+  class SimpleNodeExecStats : public NodeExecStatsInterface {
+   public:
+    explicit SimpleNodeExecStats(SimpleStepStatsCollector* step_stats_collector)
+        : step_stats_collector_(step_stats_collector) {}
+
+    void Done(const string& device) override {
+      step_stats_collector_->IncrementProcessingTime(end_time_ns_ -
+                                                     start_time_ns_);
+      delete this;
+    }
+
+    void RecordExecutorStarted() override {
+      start_time_ns_ = Env::Default()->NowNanos();
+    }
+
+    void RecordComputeStarted() override {}
+
+    void RecordComputeEnded() override {}
+
+    void RecordExecutorEnded() override {
+      end_time_ns_ = Env::Default()->NowNanos();
+    }
+
+    void SetMemory(OpKernelContext* ctx) override {}
+
+    void SetOutput(int slot, const Tensor* tensor) override {}
+
+    void SetReferencedTensors(const TensorReferenceVector& tensors) override {}
+
+    void SetScheduled(int64 nanos) override {}
+
+   private:
+    int64 start_time_ns_ = 0;
+    int64 end_time_ns_ = 0;
+    SimpleStepStatsCollector* step_stats_collector_;  // Not owned.
+  };
+
+  mutex mu_;
+  int64 processing_time_ GUARDED_BY(mu_) = 0;
+};
+
+}  // namespace
 
 /* static */
 Status CapturedFunction::Create(
     const NameAttrList& func, OpKernelContext* ctx, const string& argument,
     std::unique_ptr<CapturedFunction>* out_function) {
-  OpInputList argument_inputs;
-  TF_RETURN_IF_ERROR(ctx->input_list(argument, &argument_inputs));
-  std::vector<Tensor> arguments_t;
-  arguments_t.reserve(argument_inputs.size());
-  for (const Tensor& t : argument_inputs) {
-    arguments_t.push_back(t);
-  }
-  return CapturedFunction::Create(func, std::move(arguments_t), out_function);
+  return CapturedFunction::Create(func, ctx, argument, true, out_function);
+}
+
+Status CapturedFunction::Create(
+    const NameAttrList& func, OpKernelContext* ctx, const string& argument,
+    bool use_inter_op_parallelism,
+    std::unique_ptr<CapturedFunction>* out_function) {
+  OpInputList inputs;
+  TF_RETURN_IF_ERROR(ctx->input_list(argument, &inputs));
+  std::vector<Tensor> arguments(inputs.begin(), inputs.end());
+  *out_function = WrapUnique(new CapturedFunction(func, std::move(arguments),
+                                                  use_inter_op_parallelism));
+  return Status::OK();
 }
 
 CapturedFunction::~CapturedFunction() {
@@ -370,13 +427,13 @@
     done(s);
     return;
   }
-  auto frame =
+  OwnedArgsCallFrame* frame =
       new OwnedArgsCallFrame(std::move(args), &captured_inputs_, ret_types_);
 
   FunctionLibraryRuntime::Options f_opts;
   f_opts.step_id = CapturedFunction::generate_step_id();
   ResourceMgr* resource_mgr = ctx->lib()->device()->resource_manager();
-  auto step_container = new ScopedStepContainer(
+  ScopedStepContainer* step_container = new ScopedStepContainer(
       f_opts.step_id, [resource_mgr](const string& name) {
         resource_mgr->Cleanup(name).IgnoreError();
       });
@@ -391,24 +448,19 @@
   // (such as queue kernels) that depend on the non-nullness of
   // `OpKernelContext::cancellation_manager()`, but additional effort
   // will be required to plumb it through the `IteratorContext`.
-  auto c_mgr = new CancellationManager;
+  CancellationManager* c_mgr = new CancellationManager;
   f_opts.cancellation_manager = c_mgr;
-  StepStats* stats = nullptr;
-  StepStatsCollector* stats_collector = nullptr;
-  std::shared_ptr<model::Node> node;
+  std::shared_ptr<SimpleStepStatsCollector> stats_collector;
   if (ctx->model()) {
-    node = ctx->model()->LookupNode(prefix);
-    if (node) {
-      // TODO(b/114104975): Use something light-weight here.
-      stats = new StepStats();
-      stats_collector = new StepStatsCollector(stats);
-    }
+    stats_collector = MakeUnique<SimpleStepStatsCollector>();
   }
-  f_opts.stats_collector = stats_collector;
+  f_opts.stats_collector = stats_collector.get();
 
   auto callback = std::bind(
-      [rets, step_container, c_mgr, frame, stats, stats_collector, node](
-          FunctionLibraryRuntime::DoneCallback done,
+      [rets, step_container, c_mgr, frame](
+          const FunctionLibraryRuntime::DoneCallback& done,
+          const std::shared_ptr<model::Model>& model, const string& prefix,
+          const std::shared_ptr<SimpleStepStatsCollector>& stats_collector,
           // Begin unbound arguments.
           Status s) {
         delete step_container;
@@ -417,25 +469,17 @@
           s = frame->ConsumeRetvals(rets);
         }
         delete frame;
-        if (node) {
-          int64 delta = 0;
-          stats_collector->Finalize();
-          for (auto dev_stats : stats->dev_stats()) {
-            for (auto node_stats : dev_stats.node_stats()) {
-              delta += node_stats.all_end_rel_nanos();
-            }
-          }
-          delete stats_collector;
-          delete stats;
-          node->add_processing_time(delta);
-          node->start_work();
+        if (model) {
+          model->AddProcessingTime(prefix, stats_collector->processing_time());
+          model->RecordStart(prefix, false /* stop_output */);
         }
         done(s);
-        if (node) {
-          node->stop_work();
+        if (model) {
+          model->RecordStop(prefix, false /* start_output */);
         }
       },
-      std::move(done), std::placeholders::_1);
+      std::move(done), ctx->model(), prefix, std::move(stats_collector),
+      std::placeholders::_1);
 
   ctx->lib()->Run(f_opts, handle, frame, std::move(callback));
 }
diff --git a/tensorflow/core/kernels/data/captured_function.h b/tensorflow/core/kernels/data/captured_function.h
index 8b420fa..a10376b 100644
--- a/tensorflow/core/kernels/data/captured_function.h
+++ b/tensorflow/core/kernels/data/captured_function.h
@@ -42,29 +42,21 @@
 // context.
 class CapturedFunction {
  public:
-  // Creates a new instance from a list of named attributes and captured inputs.
-  //
-  // NOTE(mrry): The `captured_inputs` are passed by value. For
-  // efficiency, you are recommended to move this argument into the call.
-  static Status Create(const NameAttrList& func,
-                       std::vector<Tensor> captured_inputs,
-                       std::unique_ptr<CapturedFunction>* out_function);
-
-  // Creates a new instance from a list of named attributes and captured inputs.
-  //
-  // If `use_inter_op_parallelism` is false, the runtime may use an executor
-  // that is optimized for small functions.
-  static Status Create(const NameAttrList& func,
-                       std::vector<Tensor> captured_inputs,
-                       bool use_inter_op_parallelism,
-                       std::unique_ptr<CapturedFunction>* out_function);
-
   // Creates a new instance using a list of named attributes, fetching captured
   // inputs from a context argument.
   static Status Create(const NameAttrList& func, OpKernelContext* ctx,
                        const string& argument,
                        std::unique_ptr<CapturedFunction>* out_function);
 
+  // Creates a new instance using a list of named attributes, fetching captured
+  // inputs from a context argument.
+  //
+  // If `use_inter_op_parallelism` is false, the runtime may use an executor
+  // that is optimized for small functions.
+  static Status Create(const NameAttrList& func, OpKernelContext* ctx,
+                       const string& argument, bool use_inter_op_parallelism,
+                       std::unique_ptr<CapturedFunction>* out_function);
+
   ~CapturedFunction();
 
   // Runs the "Captured function" using the given FLR and caches the lib and
diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc
index bf0aeca..0088431 100644
--- a/tensorflow/core/kernels/data/filter_dataset_op.cc
+++ b/tensorflow/core/kernels/data/filter_dataset_op.cc
@@ -14,11 +14,13 @@
 ==============================================================================*/
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/stats_aggregator.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/kernels/data/captured_function.h"
 #include "tensorflow/core/kernels/data/dataset.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/lib/strings/str_util.h"
 
 namespace tensorflow {
 namespace data {
@@ -37,14 +39,6 @@
 
   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                    DatasetBase** output) override {
-    OpInputList inputs;
-    OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
-    std::vector<Tensor> other_arguments;
-    other_arguments.reserve(inputs.size());
-    for (const Tensor& t : inputs) {
-      other_arguments.push_back(t);
-    }
-
     FunctionLibraryRuntime::Handle pred_handle;
     OP_REQUIRES_OK(ctx,
                    ctx->function_library()->Instantiate(
@@ -61,9 +55,10 @@
     Node* ret_node = pred_body->ret_nodes[0];
     Node* ret_input_node;
     OP_REQUIRES_OK(ctx, ret_node->input_node(0, &ret_input_node));
+
     std::unique_ptr<CapturedFunction> captured_func;
-    OP_REQUIRES_OK(ctx, CapturedFunction::Create(
-                            func_, std::move(other_arguments), &captured_func));
+    OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+                                                 &captured_func));
 
     if (ret_input_node->def().op() == "_Arg") {
       int32 index = -1;
@@ -146,7 +141,13 @@
     class Iterator : public DatasetIterator<FilterDatasetBase> {
      public:
       explicit Iterator(const Params& params)
-          : DatasetIterator<FilterDatasetBase>(params) {}
+          : DatasetIterator<FilterDatasetBase>(params),
+            filtered_elements_(0),
+            dropped_elements_(0) {
+        std::vector<string> components =
+            str_util::Split(params.prefix, "::", str_util::SkipEmpty());
+        prefix_end_ = components.back();
+      }
 
       Status Initialize(IteratorContext* ctx) override {
         TF_RETURN_IF_ERROR(
@@ -161,6 +162,7 @@
         // `input_impl_` and `f` are thread-safe. However, if multiple
         // threads enter this method, outputs may be observed in a
         // non-deterministic order.
+        auto stats_aggregator = ctx->stats_aggregator();
         bool matched;
         do {
           {
@@ -183,8 +185,34 @@
           if (!matched) {
             // Clear the output tensor list since it didn't match.
             out_tensors->clear();
+            if (stats_aggregator) {
+              mutex_lock l(mu_);
+              dropped_elements_++;
+              stats_aggregator->AddScalar(
+                  strings::StrCat(prefix_end_, "::dropped_elements"),
+                  static_cast<float>((dropped_elements_)));
+              // TODO(shivaniagrawal): multiple pipelines would collect
+              // aggregated number of dropped elements for all the pipelines,
+              // exploit tagged_context here.
+              stats_aggregator->IncrementCounter(
+                  prefix_end_, "dropped_elements", static_cast<float>(1));
+            }
           }
         } while (!matched);
+        // TODO(shivaniagrawal): add ratio of dropped_elements and
+        // filtered_elements as a histogram.
+        if (stats_aggregator) {
+          mutex_lock l(mu_);
+          filtered_elements_++;
+          stats_aggregator->AddScalar(
+              strings::StrCat(prefix_end_, "::filtered_elements"),
+              static_cast<float>((filtered_elements_)));
+          // TODO(shivaniagrawal): multiple pipelines would collect aggregated
+          // number of filtered elements for all the pipelines, exploit
+          // tagged_context here.
+          stats_aggregator->IncrementCounter(prefix_end_, "filtered_elements",
+                                             static_cast<float>(1));
+        }
         *end_of_sequence = false;
         return Status::OK();
       }
@@ -197,6 +225,10 @@
         else
           TF_RETURN_IF_ERROR(
               writer->WriteScalar(full_name("input_impls_empty"), ""));
+        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("filtered_elements"),
+                                               filtered_elements_));
+        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("dropped_elements"),
+                                               dropped_elements_));
         return Status::OK();
       }
 
@@ -207,12 +239,19 @@
           input_impl_.reset();
         else
           TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("filtered_elements"),
+                                              &filtered_elements_));
+        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("dropped_elements"),
+                                              &dropped_elements_));
         return Status::OK();
       }
 
      private:
       mutex mu_;
       std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+      int64 filtered_elements_ GUARDED_BY(mu_);
+      int64 dropped_elements_ GUARDED_BY(mu_);
+      string prefix_end_;
     };
 
     const DatasetBase* const input_;
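Note: the filter iterator above now reports per-iterator "dropped_elements" and "filtered_elements" through the context's stats aggregator. A standalone sketch of that bookkeeping pattern follows; the `Stats` struct is a stand-in for the StatsAggregator interface used in the diff, not the real class.

// Sketch only: counts kept/dropped elements and publishes them, mirroring
// the AddScalar/IncrementCounter calls added in the hunk above.
#include <iostream>
#include <string>
#include <vector>

struct Stats {
  void AddScalar(const std::string& name, float value) {
    std::cout << "scalar " << name << " = " << value << "\n";
  }
  void IncrementCounter(const std::string& prefix, const std::string& name,
                        float delta) {
    std::cout << "counter " << prefix << "::" << name << " += " << delta << "\n";
  }
};

int main() {
  Stats stats;
  std::vector<int> elements = {1, 2, 3, 4, 5, 6};
  auto predicate = [](int x) { return x % 2 == 0; };  // keep even elements
  long long filtered = 0, dropped = 0;
  for (int x : elements) {
    if (predicate(x)) {
      ++filtered;
      stats.AddScalar("Filter::filtered_elements", static_cast<float>(filtered));
      stats.IncrementCounter("Filter", "filtered_elements", 1);
    } else {
      ++dropped;
      stats.AddScalar("Filter::dropped_elements", static_cast<float>(dropped));
      stats.IncrementCounter("Filter", "dropped_elements", 1);
    }
  }
  return 0;
}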
diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op.cc b/tensorflow/core/kernels/data/flat_map_dataset_op.cc
index e3c45ef..2fada22 100644
--- a/tensorflow/core/kernels/data/flat_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/flat_map_dataset_op.cc
@@ -39,18 +39,9 @@
 
   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                    DatasetBase** output) override {
-    OpInputList inputs;
-    OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
-    std::vector<Tensor> other_arguments;
-    other_arguments.reserve(inputs.size());
-    for (const Tensor& t : inputs) {
-      other_arguments.push_back(t);
-    }
-
     std::unique_ptr<CapturedFunction> captured_func;
-    OP_REQUIRES_OK(ctx, CapturedFunction::Create(
-                            func_, std::move(other_arguments), &captured_func));
-
+    OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+                                                 &captured_func));
     *output = new Dataset(ctx, input, func_, std::move(captured_func),
                           output_types_, output_shapes_);
   }
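Note: this file, like most of the dataset kernels in this change, drops the manual copy of the "other_arguments" input list in favour of the `CapturedFunction::Create(func_, ctx, "other_arguments", &captured_func)` overload, which reads the named input list itself. A rough standalone sketch of that factoring, using hypothetical `Ctx`/`Captured` types rather than the TensorFlow classes:

// Sketch only: the create function takes the context and the input-list
// name, so callers no longer copy the list into a std::vector themselves.
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

struct Ctx {
  std::map<std::string, std::vector<int>> input_lists;  // stand-in for Tensor lists
  const std::vector<int>& input_list(const std::string& name) const {
    return input_lists.at(name);
  }
};

struct Captured {
  std::string fn;
  std::vector<int> args;
};

// Mirrors the shape of CapturedFunction::Create(func, ctx, arg_name, &out).
bool CreateCaptured(const std::string& fn, const Ctx& ctx,
                    const std::string& arg_name,
                    std::unique_ptr<Captured>* out) {
  const auto& inputs = ctx.input_list(arg_name);
  out->reset(new Captured{fn, std::vector<int>(inputs.begin(), inputs.end())});
  return true;
}

int main() {
  Ctx ctx;
  ctx.input_lists["other_arguments"] = {1, 2, 3};
  std::unique_ptr<Captured> captured;
  CreateCaptured("f", ctx, "other_arguments", &captured);
  std::cout << captured->args.size() << " captured args\n";
  return 0;
}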
diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc
index ac5cc1b..71a3631 100644
--- a/tensorflow/core/kernels/data/generator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/generator_dataset_op.cc
@@ -145,44 +145,18 @@
 
 void GeneratorDatasetOp::MakeDataset(OpKernelContext* ctx,
                                      DatasetBase** output) {
-  OpInputList init_func_other_args_input;
-  OP_REQUIRES_OK(ctx, ctx->input_list("init_func_other_args",
-                                      &init_func_other_args_input));
-  std::vector<Tensor> init_func_other_args;
-  init_func_other_args.reserve(init_func_other_args_input.size());
-  for (const Tensor& t : init_func_other_args_input) {
-    init_func_other_args.push_back(t);
-  }
   std::unique_ptr<CapturedFunction> init_func;
-  OP_REQUIRES_OK(
-      ctx, CapturedFunction::Create(init_func_, std::move(init_func_other_args),
-                                    &init_func));
-
-  OpInputList next_func_other_args_input;
-  OP_REQUIRES_OK(ctx, ctx->input_list("next_func_other_args",
-                                      &next_func_other_args_input));
-  std::vector<Tensor> next_func_other_args;
-  next_func_other_args.reserve(next_func_other_args_input.size());
-  for (const Tensor& t : next_func_other_args_input) {
-    next_func_other_args.push_back(t);
-  }
-  std::unique_ptr<CapturedFunction> next_func;
-  OP_REQUIRES_OK(
-      ctx, CapturedFunction::Create(next_func_, std::move(next_func_other_args),
-                                    &next_func));
-
-  OpInputList finalize_func_other_args_input;
-  OP_REQUIRES_OK(ctx, ctx->input_list("finalize_func_other_args",
-                                      &finalize_func_other_args_input));
-  std::vector<Tensor> finalize_func_other_args;
-  finalize_func_other_args.reserve(finalize_func_other_args_input.size());
-  for (const Tensor& t : finalize_func_other_args_input) {
-    finalize_func_other_args.push_back(t);
-  }
-  std::unique_ptr<CapturedFunction> finalize_func;
   OP_REQUIRES_OK(ctx, CapturedFunction::Create(
-                          finalize_func_, std::move(finalize_func_other_args),
-                          &finalize_func));
+                          init_func_, ctx, "init_func_other_args", &init_func));
+
+  std::unique_ptr<CapturedFunction> next_func;
+  OP_REQUIRES_OK(ctx, CapturedFunction::Create(
+                          next_func_, ctx, "next_func_other_args", &next_func));
+
+  std::unique_ptr<CapturedFunction> finalize_func;
+  OP_REQUIRES_OK(ctx, CapturedFunction::Create(finalize_func_, ctx,
+                                               "finalize_func_other_args",
+                                               &finalize_func));
 
   *output =
       new Dataset(ctx, std::move(init_func), std::move(next_func),
diff --git a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
index e4fa557..8b417bb 100644
--- a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
+++ b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
@@ -42,50 +42,19 @@
 
   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                    DatasetBase** output) override {
-    // Get captured inputs for the key, reduce, and window_size functions.
-    OpInputList key_func_other_argument_inputs;
-    OP_REQUIRES_OK(ctx, ctx->input_list("key_func_other_arguments",
-                                        &key_func_other_argument_inputs));
-    std::vector<Tensor> key_func_other_arguments;
-    key_func_other_arguments.reserve(key_func_other_argument_inputs.size());
-    for (const Tensor& t : key_func_other_argument_inputs) {
-      key_func_other_arguments.push_back(t);
-    }
-    OpInputList reduce_func_other_argument_inputs;
-    OP_REQUIRES_OK(ctx, ctx->input_list("reduce_func_other_arguments",
-                                        &reduce_func_other_argument_inputs));
-    std::vector<Tensor> reduce_func_other_arguments;
-    reduce_func_other_arguments.reserve(
-        reduce_func_other_argument_inputs.size());
-    for (const Tensor& t : reduce_func_other_argument_inputs) {
-      reduce_func_other_arguments.push_back(t);
-    }
-    OpInputList window_size_func_other_argument_inputs;
-    OP_REQUIRES_OK(ctx,
-                   ctx->input_list("window_size_func_other_arguments",
-                                   &window_size_func_other_argument_inputs));
-    std::vector<Tensor> window_size_func_other_arguments;
-    window_size_func_other_arguments.reserve(
-        window_size_func_other_argument_inputs.size());
-    for (const Tensor& t : window_size_func_other_argument_inputs) {
-      window_size_func_other_arguments.push_back(t);
-    }
-    // TODO(mrry): Refactor CapturedFunction to share the runtime
-    // state between multiple functions?
     std::unique_ptr<CapturedFunction> captured_key_func;
-    OP_REQUIRES_OK(ctx, CapturedFunction::Create(
-                            key_func_, std::move(key_func_other_arguments),
-                            &captured_key_func));
+    OP_REQUIRES_OK(ctx, CapturedFunction::Create(key_func_, ctx,
+                                                 "key_func_other_arguments",
+                                                 &captured_key_func));
     std::unique_ptr<CapturedFunction> captured_reduce_func;
-    OP_REQUIRES_OK(
-        ctx, CapturedFunction::Create(reduce_func_,
-                                      std::move(reduce_func_other_arguments),
-                                      &captured_reduce_func));
+    OP_REQUIRES_OK(ctx, CapturedFunction::Create(reduce_func_, ctx,
+                                                 "reduce_func_other_arguments",
+                                                 &captured_reduce_func));
     std::unique_ptr<CapturedFunction> captured_window_size_func;
-    OP_REQUIRES_OK(
-        ctx, CapturedFunction::Create(
-                 window_size_func_, std::move(window_size_func_other_arguments),
-                 &captured_window_size_func));
+    OP_REQUIRES_OK(ctx,
+                   CapturedFunction::Create(window_size_func_, ctx,
+                                            "window_size_func_other_arguments",
+                                            &captured_window_size_func));
 
     *output = new Dataset(
         ctx, input, key_func_, reduce_func_, window_size_func_,
diff --git a/tensorflow/core/kernels/data/interleave_dataset_op.cc b/tensorflow/core/kernels/data/interleave_dataset_op.cc
index 0768f46..0aa802b 100644
--- a/tensorflow/core/kernels/data/interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/interleave_dataset_op.cc
@@ -39,14 +39,6 @@
 
   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                    DatasetBase** output) override {
-    OpInputList inputs;
-    OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
-    std::vector<Tensor> other_arguments;
-    other_arguments.reserve(inputs.size());
-    for (const Tensor& t : inputs) {
-      other_arguments.push_back(t);
-    }
-
     const Tensor* cycle_length_t;
     OP_REQUIRES_OK(ctx, ctx->input("cycle_length", &cycle_length_t));
     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(cycle_length_t->shape()),
@@ -66,8 +58,8 @@
         errors::InvalidArgument("block_length must be greater than zero."));
 
     std::unique_ptr<CapturedFunction> captured_func;
-    OP_REQUIRES_OK(ctx, CapturedFunction::Create(
-                            func_, std::move(other_arguments), &captured_func));
+    OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+                                                 &captured_func));
 
     *output =
         new Dataset(ctx, input, func_, std::move(captured_func), cycle_length,
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index 85e4935..2bbf4af 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -14,6 +14,7 @@
 ==============================================================================*/
 #define EIGEN_USE_THREADS
 
+#include <atomic>
 #include <utility>
 
 #include "tensorflow/core/common_runtime/function.h"
@@ -26,6 +27,7 @@
 #include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/lib/random/random.h"
 #include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/cpu_info.h"
 #include "tensorflow/core/platform/tracing.h"
 
 namespace tensorflow {
@@ -39,7 +41,6 @@
  public:
   explicit MapAndBatchDatasetOp(OpKernelConstruction* ctx)
       : UnaryDatasetOpKernel(ctx),
-        graph_def_version_(ctx->graph_def_version()),
         op_version_(ctx->def().op() == "MapAndBatchDataset" ? 1 : 2) {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
@@ -49,14 +50,6 @@
  protected:
   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                    DatasetBase** output) override {
-    OpInputList inputs;
-    OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
-    std::vector<Tensor> other_arguments;
-    other_arguments.reserve(inputs.size());
-    for (const Tensor& t : inputs) {
-      other_arguments.push_back(t);
-    }
-
     int64 batch_size;
     OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "batch_size", &batch_size));
     OP_REQUIRES(
@@ -77,7 +70,8 @@
       case 2:
         OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
                                                 &num_parallel_calls));
-        OP_REQUIRES(ctx, num_parallel_calls > 0,
+        OP_REQUIRES(ctx,
+                    num_parallel_calls > 0 || num_parallel_calls == kAutoTune,
                     errors::InvalidArgument(
                         "num_parallel_calls must be greater than zero."));
         break;
@@ -92,8 +86,8 @@
                    ParseScalarArgument(ctx, "drop_remainder", &drop_remainder));
 
     std::unique_ptr<CapturedFunction> captured_func;
-    OP_REQUIRES_OK(ctx, CapturedFunction::Create(
-                            func_, std::move(other_arguments), &captured_func));
+    OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+                                                 &captured_func));
 
     *output = new Dataset(ctx, input, batch_size, num_parallel_calls,
                           drop_remainder, output_types_, output_shapes_, func_,
@@ -190,7 +184,8 @@
     class Iterator : public DatasetIterator<Dataset> {
      public:
       explicit Iterator(const Params& params)
-          : DatasetIterator<Dataset>(params) {}
+          : DatasetIterator<Dataset>(params),
+            num_parallel_calls_(params.dataset->num_parallel_calls_) {}
 
       ~Iterator() override {
         mutex_lock l(mu_);
@@ -204,8 +199,16 @@
       }
 
       Status Initialize(IteratorContext* ctx) override {
-        SetMetadata(ctx, "batch_size", dataset()->batch_size_);
-        SetMetadata(ctx, "parallelism", dataset()->num_parallel_calls_);
+        mutex_lock l(mu_);
+        AddConstantParameter(ctx, "batch_size", dataset()->batch_size_);
+        if (num_parallel_calls_ == kAutoTune) {
+          num_parallel_calls_ = 1;
+          AddTunableParameter(ctx, "parallelism",
+                              &num_parallel_calls_ /* value */, 1 /* min */,
+                              port::NumSchedulableCPUs() /* max */, &cond_var_);
+        } else {
+          AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
+        }
         TF_RETURN_IF_ERROR(
             dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
         return dataset()->captured_func_->Instantiate(ctx);
@@ -220,14 +223,14 @@
           EnsureRunnerThreadStarted(ctx);
           while (batch_results_.empty() ||
                  batch_results_.front()->num_calls > 0) {
-            StopWork(ctx);
+            RecordStop(ctx);
             cond_var_.wait(l);
-            StartWork(ctx);
+            RecordStart(ctx);
           }
           std::swap(result, batch_results_.front());
           batch_results_.pop_front();
+          cond_var_.notify_all();
         }
-        cond_var_.notify_all();
         return ProcessResult(ctx, result, out_tensors, end_of_sequence);
       }
 
@@ -330,11 +333,9 @@
 
       void CallCompleted(const std::shared_ptr<BatchResult>& result)
           LOCKS_EXCLUDED(mu_) {
-        {
-          mutex_lock l(mu_);
-          num_calls_--;
-          result->num_calls--;
-        }
+        mutex_lock l(mu_);
+        num_calls_--;
+        result->num_calls--;
         cond_var_.notify_all();
       }
 
@@ -427,11 +428,6 @@
         result->output_allocated = true;
       }
 
-      int MaxBatchResults() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-        return (dataset()->num_parallel_calls_ + dataset()->batch_size_ - 1) /
-               dataset()->batch_size_;
-      }
-
       Status ProcessResult(IteratorContext* ctx,
                            const std::shared_ptr<BatchResult>& result,
                            std::vector<Tensor>* out_tensors,
@@ -480,31 +476,34 @@
       void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
           LOCKS_EXCLUDED(mu_) {
         std::vector<std::pair<std::shared_ptr<BatchResult>, int64>> new_calls;
-        new_calls.reserve(dataset()->num_parallel_calls_);
-        StartWork(ctx.get());
+        RecordStart(ctx.get());
         auto stop_cleanup =
-            gtl::MakeCleanup([this, &ctx]() { StopWork(ctx.get()); });
+            gtl::MakeCleanup([this, &ctx]() { RecordStop(ctx.get()); });
+        new_calls.reserve(num_parallel_calls_);
+        auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool {
+          int64 num_parallel_calls = num_parallel_calls_;
+          int64 max_batch_results =
+              (num_parallel_calls + dataset()->batch_size_ - 1) /
+              dataset()->batch_size_;
+          return num_calls_ >= num_parallel_calls ||
+                 (batch_results_.size() > max_batch_results ||
+                  (batch_results_.size() == max_batch_results &&
+                   call_counter_ % dataset()->batch_size_ == 0));
+        };
         while (true) {
           {
             mutex_lock l(mu_);
-            while (!cancelled_ &&
-                   (num_calls_ >= dataset()->num_parallel_calls_ ||
-                    batch_results_.size() > MaxBatchResults() ||
-                    (batch_results_.size() == MaxBatchResults() &&
-                     call_counter_ % dataset()->batch_size_ == 0))) {
-              StopWork(ctx.get());
+            while (!cancelled_ && busy()) {
+              RecordStop(ctx.get());
               cond_var_.wait(l);
-              StartWork(ctx.get());
+              RecordStart(ctx.get());
             }
 
             if (cancelled_) {
               return;
             }
 
-            while (num_calls_ < dataset()->num_parallel_calls_ &&
-                   (batch_results_.size() < MaxBatchResults() ||
-                    (batch_results_.size() == MaxBatchResults() &&
-                     call_counter_ % dataset()->batch_size_ != 0))) {
+            while (!busy()) {
               if (call_counter_ % dataset()->batch_size_ == 0) {
                 batch_results_.emplace_back(
                     new BatchResult(dataset()->batch_size_));
@@ -648,6 +647,8 @@
       // user specified level of parallelism and there are slots available in
       // the `batch_results_` buffer.
       condition_variable cond_var_;
+      // Identifies the maximum number of parallel calls.
+      std::atomic<int64> num_parallel_calls_;
       // Counts the number of outstanding calls for this batch.
       int64 num_calls_ GUARDED_BY(mu_) = 0;
       // Counts the total number of calls.
@@ -671,7 +672,6 @@
     const Eigen::ThreadPoolDevice* device_;  // not owned
   };
 
-  const int graph_def_version_;
   const int op_version_;
   DataTypeVector output_types_;
   std::vector<PartialTensorShape> output_shapes_;
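Note: the new `busy()` predicate above bounds the number of buffered batches by ceil(num_parallel_calls / batch_size) in addition to capping in-flight calls. A standalone sketch of that arithmetic, with illustrative values:

// Sketch only: reproduces the decision made by the busy() lambda in the
// runner thread, for made-up counter values.
#include <cstdint>
#include <iostream>

bool Busy(int64_t num_parallel_calls, int64_t batch_size, int64_t num_calls,
          int64_t num_batch_results, int64_t call_counter) {
  // At most ceil(num_parallel_calls / batch_size) batches may be buffered.
  int64_t max_batch_results =
      (num_parallel_calls + batch_size - 1) / batch_size;
  return num_calls >= num_parallel_calls ||
         (num_batch_results > max_batch_results ||
          (num_batch_results == max_batch_results &&
           call_counter % batch_size == 0));
}

int main() {
  // With parallelism 4 and batch size 3, at most 2 batch results are kept.
  std::cout << Busy(4, 3, 4, 1, 5) << "\n";  // 1: all call slots in use
  std::cout << Busy(4, 3, 2, 2, 6) << "\n";  // 1: buffer full at a batch boundary
  std::cout << Busy(4, 3, 2, 1, 5) << "\n";  // 0: room for another call
  return 0;
}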
diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc
index af301e2..f112e1d 100644
--- a/tensorflow/core/kernels/data/map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_dataset_op.cc
@@ -38,18 +38,10 @@
 
   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                    DatasetBase** output) override {
-    OpInputList inputs;
-    OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
-    std::vector<Tensor> other_arguments;
-    other_arguments.reserve(inputs.size());
-    for (const Tensor& t : inputs) {
-      other_arguments.push_back(t);
-    }
-
     std::unique_ptr<CapturedFunction> captured_func;
-    OP_REQUIRES_OK(ctx, CapturedFunction::Create(
-                            func_, std::move(other_arguments),
-                            use_inter_op_parallelism_, &captured_func));
+    OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+                                                 use_inter_op_parallelism_,
+                                                 &captured_func));
 
     *output = new Dataset(ctx, input, func_, std::move(captured_func),
                           output_types_, output_shapes_);
diff --git a/tensorflow/core/kernels/data/model_dataset_op.cc b/tensorflow/core/kernels/data/model_dataset_op.cc
index c7f929d..9aa505f 100644
--- a/tensorflow/core/kernels/data/model_dataset_op.cc
+++ b/tensorflow/core/kernels/data/model_dataset_op.cc
@@ -17,11 +17,14 @@
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/kernels/data/dataset.h"
 #include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/cpu_info.h"
 
 namespace tensorflow {
 namespace data {
 namespace {
 
+const int kOptimizationPeriodThresholdMs = 60 * EnvTime::kSecondsToMillis;
+
 class ModelDatasetOp : public UnaryDatasetOpKernel {
  public:
   explicit ModelDatasetOp(OpKernelConstruction* ctx)
@@ -71,9 +74,16 @@
     class Iterator : public DatasetIterator<Dataset> {
      public:
       explicit Iterator(const Params& params)
-          : DatasetIterator<Dataset>(params), model_(new model::Model()) {}
+          : DatasetIterator<Dataset>(params),
+            model_(std::make_shared<model::Model>()) {}
 
-      ~Iterator() override { model_->OutputToFile(); }
+      ~Iterator() override {
+        // Signal the optimize thread to terminate. It is joined when
+        // `this->optimize_thread_` is destroyed.
+        mutex_lock l(mu_);
+        cancelled_ = true;
+        cond_var_.notify_all();
+      }
 
       Status Initialize(IteratorContext* ctx) override {
         IteratorContext ctx_with_model(CreateParams(ctx));
@@ -85,6 +95,7 @@
                              std::vector<Tensor>* out_tensors,
                              bool* end_of_sequence) override {
         mutex_lock l(mu_);
+        TF_RETURN_IF_ERROR(EnsureOptimizeThreadStarted(ctx));
         IteratorContext ctx_with_model(CreateParams(ctx));
         return input_impl_->GetNext(&ctx_with_model, out_tensors,
                                     end_of_sequence);
@@ -111,8 +122,53 @@
       }
 
      private:
+      Status EnsureOptimizeThreadStarted(IteratorContext* ctx)
+          EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+        if (!optimize_thread_) {
+          std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
+          optimize_thread_.reset(ctx->env()->StartThread(
+              {}, "optimize_thread",
+              [this, new_ctx]() { OptimizeThread(new_ctx); }));
+        }
+        return Status::OK();
+      }
+
+      void OptimizeThread(const std::shared_ptr<IteratorContext>& ctx) {
+        int64 last_optimization_ms = 0;
+        int64 optimization_period_ms = 10;
+        while (true) {
+          {
+            mutex_lock l(mu_);
+            while (!cancelled_ &&
+                   last_optimization_ms + optimization_period_ms >=
+                       ctx->env()->NowMicros() / EnvTime::kMillisToMicros) {
+              cond_var_.wait_for(
+                  l, std::chrono::milliseconds(
+                         last_optimization_ms + optimization_period_ms -
+                         ctx->env()->NowMicros() / EnvTime::kMillisToMicros));
+            }
+            if (cancelled_) return;
+          }
+          model_->Optimize(port::NumSchedulableCPUs());
+          // Exponentially increase the optimization period until the
+          // threshold is reached.
+          if (optimization_period_ms < kOptimizationPeriodThresholdMs) {
+            if (optimization_period_ms << 1 < kOptimizationPeriodThresholdMs) {
+              optimization_period_ms <<= 1;
+            } else {
+              optimization_period_ms = kOptimizationPeriodThresholdMs;
+            }
+          }
+          last_optimization_ms =
+              ctx->env()->NowMicros() / EnvTime::kMillisToMicros;
+        }
+      }
+
       mutex mu_;
+      condition_variable cond_var_;
       std::shared_ptr<model::Model> model_;
+      std::unique_ptr<Thread> optimize_thread_ GUARDED_BY(mu_);
+      bool cancelled_ GUARDED_BY(mu_) = false;
       std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
     };
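Note: `OptimizeThread` starts with a 10 ms optimization period and doubles it until the threshold is hit. A standalone sketch of that back-off schedule, assuming the threshold is 60 seconds as the constant's name suggests:

// Sketch only: prints the sequence of optimization periods produced by the
// doubling-with-cap logic in the hunk above.
#include <cstdint>
#include <iostream>

int main() {
  const int64_t kThresholdMs = 60 * 1000;
  int64_t period_ms = 10;
  while (period_ms < kThresholdMs) {
    std::cout << period_ms << " ms\n";
    if (period_ms << 1 < kThresholdMs) {
      period_ms <<= 1;
    } else {
      period_ms = kThresholdMs;
    }
  }
  std::cout << period_ms << " ms (capped)\n";  // 10, 20, 40, ..., 40960, 60000
  return 0;
}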
 
diff --git a/tensorflow/core/kernels/data/optional_ops.cc b/tensorflow/core/kernels/data/optional_ops.cc
index 6180df5..346e4ce 100644
--- a/tensorflow/core/kernels/data/optional_ops.cc
+++ b/tensorflow/core/kernels/data/optional_ops.cc
@@ -108,11 +108,8 @@
   void Compute(OpKernelContext* ctx) override {
     OpInputList components_input;
     OP_REQUIRES_OK(ctx, ctx->input_list("components", &components_input));
-    std::vector<Tensor> components;
-    components.reserve(components_input.size());
-    for (const Tensor& component_t : components_input) {
-      components.push_back(component_t);
-    }
+    std::vector<Tensor> components(components_input.begin(),
+                                   components_input.end());
     OP_REQUIRES_OK(
         ctx, WriteOptionalWithValueToOutput(ctx, 0, std::move(components)));
   }
diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
index 73eeafd..7b01c3b 100644
--- a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
@@ -207,7 +207,7 @@
           : DatasetIterator<Dataset>(params) {}
 
       Status Initialize(IteratorContext* ctx) override {
-        SetMetadata(ctx, "batch_size", dataset()->batch_size_);
+        AddConstantParameter(ctx, "batch_size", dataset()->batch_size_);
         return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
       }
 
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index aa5e613..2e6e046 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -12,6 +12,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include <atomic>
 #include <deque>
 #include <utility>
 
@@ -44,14 +45,6 @@
 
   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                    DatasetBase** output) override {
-    OpInputList inputs;
-    OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
-    std::vector<Tensor> other_arguments;
-    other_arguments.reserve(inputs.size());
-    for (const Tensor& t : inputs) {
-      other_arguments.push_back(t);
-    }
-
     int64 cycle_length = 0;
     OP_REQUIRES_OK(ctx,
                    ParseScalarArgument(ctx, "cycle_length", &cycle_length));
@@ -83,8 +76,8 @@
 
     std::unique_ptr<CapturedFunction> captured_func;
     OP_REQUIRES_OK(
-        ctx, CapturedFunction::Create(
-                 interleave_func_, std::move(other_arguments), &captured_func));
+        ctx, CapturedFunction::Create(interleave_func_, ctx, "other_arguments",
+                                      &captured_func));
 
     *output =
         new Dataset(ctx, input, interleave_func_, std::move(captured_func),
@@ -252,7 +245,7 @@
       }
 
       Status Initialize(IteratorContext* ctx) override {
-        SetMetadata(ctx, "parallelism", dataset()->cycle_length_);
+        AddConstantParameter(ctx, "parallelism", dataset()->cycle_length_);
         TF_RETURN_IF_ERROR(
             dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
         return dataset()->captured_func_->Instantiate(ctx);
@@ -352,13 +345,13 @@
 
           if (must_wait_for_input) {
             // Wait for elements to become available.
-            StopWork(ctx);
+            RecordStop(ctx);
             if (dataset()->sloppy_) {
               sloppy_cond_var_.wait(l);
             } else {
               workers_[interleave_indices_[next_index_]].cond_var.wait(l);
             }
-            StartWork(ctx);
+            RecordStart(ctx);
           }
         }
         return errors::Cancelled(
@@ -626,11 +619,11 @@
 
         // std::function arguments are copy-constructible, so we pass raw
         // pointers, and then immediately wrap them to ensure correct ownership.
-        StartWork(ctx.get());
+        RecordStart(ctx.get());
         auto cleanup = gtl::MakeCleanup([this, thread_index, ctx] {
           mutex_lock l(mu_);
           workers_[thread_index].cond_var.notify_all();
-          StopWork(ctx.get());
+          RecordStop(ctx.get());
         });
         bool make_new_iterator;
         {
@@ -668,9 +661,9 @@
             if (read_new_input) {
               mutex_lock l(mu_);
               while (!cancelled_ && !workers_[thread_index].is_producing) {
-                StopWork(ctx.get());
+                RecordStop(ctx.get());
                 workers_[thread_index].cond_var.wait(l);
-                StartWork(ctx.get());
+                RecordStart(ctx.get());
               }
               if (cancelled_) return;
               // Copy the input tensors so that we do not need to block on `mu_`
@@ -720,9 +713,9 @@
             // Wait for space in the prefetch queue.
             while (!cancelled_ && workers_[thread_index].outputs.size() ==
                                       dataset()->buffer_output_elements_) {
-              StopWork(ctx.get());
+              RecordStop(ctx.get());
               workers_[thread_index].cond_var.wait(l);
-              StartWork(ctx.get());
+              RecordStart(ctx.get());
             }
             if (cancelled_) return;
             tf_shared_lock ckpt_l(ckpt_mu_);
@@ -771,9 +764,9 @@
                 // Wait for space in the prefetch queue.
                 while (!cancelled_ && workers_[thread_index].outputs.size() ==
                                           dataset()->buffer_output_elements_) {
-                  StopWork(ctx.get());
+                  RecordStop(ctx.get());
                   workers_[thread_index].cond_var.wait(l);
-                  StartWork(ctx.get());
+                  RecordStart(ctx.get());
                 }
                 if (cancelled_) return;
 
@@ -1102,9 +1095,6 @@
 
   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                    DatasetBase** output) override {
-    OpInputList inputs;
-    OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
-
     int64 cycle_length = 0;
     OP_REQUIRES_OK(ctx,
                    ParseScalarArgument(ctx, "cycle_length", &cycle_length));
@@ -1120,7 +1110,7 @@
     int64 num_parallel_calls;
     OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
                                             &num_parallel_calls));
-    OP_REQUIRES(ctx, num_parallel_calls > 0,
+    OP_REQUIRES(ctx, num_parallel_calls > 0 || num_parallel_calls == kAutoTune,
                 errors::InvalidArgument(
                     "num_parallel_calls must be greater than zero."));
     OP_REQUIRES(
@@ -1128,16 +1118,10 @@
         errors::InvalidArgument(
             "num_parallel_calls must less than or equal to cycle_length."));
 
-    // TODO(b/114267189): Use `other_arguments(inputs.begin(), inputs.end());`.
-    std::vector<Tensor> other_arguments;
-    other_arguments.reserve(inputs.size());
-    for (const Tensor& t : inputs) {
-      other_arguments.push_back(t);
-    }
     std::unique_ptr<CapturedFunction> captured_func;
     OP_REQUIRES_OK(
-        ctx, CapturedFunction::Create(
-                 interleave_func_, std::move(other_arguments), &captured_func));
+        ctx, CapturedFunction::Create(interleave_func_, ctx, "other_arguments",
+                                      &captured_func));
 
     *output = new Dataset(ctx, input, interleave_func_,
                           std::move(captured_func), cycle_length, block_length,
@@ -1230,6 +1214,7 @@
      public:
       explicit Iterator(const Params& params)
           : DatasetIterator<Dataset>(params),
+            num_parallel_calls_(params.dataset->num_parallel_calls_),
             args_list_(params.dataset->cycle_length_),
             current_elements_(params.dataset->cycle_length_),
             element_in_use_(params.dataset->cycle_length_, false),
@@ -1250,7 +1235,16 @@
       }
 
       Status Initialize(IteratorContext* ctx) override {
-        SetMetadata(ctx, "parallelism", dataset()->num_parallel_calls_);
+        mutex_lock l(mu_);
+        if (num_parallel_calls_ == kAutoTune) {
+          num_parallel_calls_ = 1;
+          AddTunableParameter(ctx, "parallelism",
+                              &num_parallel_calls_ /* value */, 1 /* min */,
+                              dataset()->cycle_length_ /* max */, &cond_var_);
+        } else {
+          AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
+        }
+        AddConstantParameter(ctx, "cycle_length", dataset()->cycle_length_);
         TF_RETURN_IF_ERROR(
             dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
         return dataset()->captured_func_->Instantiate(ctx);
@@ -1266,9 +1260,9 @@
             EnsureRunnerThreadStarted(ctx);
             while (invocation_results_.empty() &&
                    (!end_of_input_ || num_open_ > 0)) {
-              StopWork(ctx);
+              RecordStop(ctx);
               cond_var_.wait(l);
-              StartWork(ctx);
+              RecordStart(ctx);
             }
             if (!invocation_results_.empty()) {
               std::swap(result, invocation_results_.front());
@@ -1277,11 +1271,11 @@
               *end_of_sequence = true;
               return Status::OK();
             }
+            cond_var_.notify_all();
           }
-          cond_var_.notify_all();
-          StopWork(ctx);
+          RecordStop(ctx);
           result->notification.WaitForNotification();
-          StartWork(ctx);
+          RecordStart(ctx);
         } while (result->skip);
 
         if (result->status.ok()) {
@@ -1405,8 +1399,8 @@
           const std::shared_ptr<IteratorContext>& ctx, int64 cycle_index,
           const std::vector<std::shared_ptr<InvocationResult>>& results)
           LOCKS_EXCLUDED(mu_) {
-        StartWork(ctx.get());
-        auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); });
+        RecordStart(ctx.get());
+        auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
         bool end_of_input = false;
         for (auto& result : results) {
           if (!end_of_input) {
@@ -1424,60 +1418,66 @@
 
         // Release the ownership of the cycle element iterator, closing the
         // iterator if end of input was encountered.
-        {
-          if (end_of_input) {
-            current_elements_[cycle_index].reset();
-          }
-          mutex_lock l(mu_);
-          element_in_use_[cycle_index] = false;
-          num_calls_--;
-          if (end_of_input) {
-            args_list_[cycle_index].clear();
-            num_open_--;
-          }
+        if (end_of_input) {
+          current_elements_[cycle_index].reset();
+        }
+        mutex_lock l(mu_);
+        element_in_use_[cycle_index] = false;
+        num_calls_--;
+        if (end_of_input) {
+          args_list_[cycle_index].clear();
+          num_open_--;
         }
         cond_var_.notify_all();
       }
 
-      int64 MaxInvocationResults() {
-        return dataset()->cycle_length_ * dataset()->block_length_;
-      }
-
       // Method responsible for 1) creating iterators out of input elements, 2)
       // determining the order in which elements are fetched from the iterators,
       // and 3) scheduling the fetching of the elements to a threadpool.
       //
       // This method runs in the `runner_thread` background thread.
       void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
-        StartWork(ctx.get());
-        auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); });
+        RecordStart(ctx.get());
+        auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
+        auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool {
+          return element_in_use_[cycle_index_] ||
+                 num_calls_ >= num_parallel_calls_ ||
+                 invocation_results_.size() >=
+                     dataset()->cycle_length_ * dataset()->block_length_;
+        };
         while (true) {
-          {
-            mutex_lock l(mu_);
-            // Wait until this thread is cancelled, the end of input has been
-            // reached, or the cycle element at the `cycle_index_` position is
-            // not in use and there is space in the `invocation_results_` queue.
-            while (!cancelled_ && (!end_of_input_ || num_open_ > 0) &&
-                   (element_in_use_[cycle_index_] ||
-                    num_calls_ >= dataset()->num_parallel_calls_ ||
-                    invocation_results_.size() >= MaxInvocationResults())) {
-              StopWork(ctx.get());
-              cond_var_.wait(l);
-              StartWork(ctx.get());
-            }
+          mutex_lock l(mu_);
+          // Wait until this thread is cancelled, the end of input has been
+          // reached, or the cycle element at the `cycle_index_` position is
+          // not in use and there is space in the `invocation_results_` queue.
+          while (!cancelled_ && (!end_of_input_ || num_open_ > 0) && busy()) {
+            RecordStop(ctx.get());
+            cond_var_.wait(l);
+            RecordStart(ctx.get());
+          }
 
-            if (cancelled_ || (end_of_input_ && num_open_ == 0)) {
-              return;
-            }
+          if (cancelled_ || (end_of_input_ && num_open_ == 0)) {
+            return;
+          }
 
-            while (!element_in_use_[cycle_index_] &&
-                   (!end_of_input_ || num_open_ > 0) &&
-                   num_calls_ < dataset()->num_parallel_calls_ &&
-                   invocation_results_.size() < MaxInvocationResults()) {
-              if (!current_elements_[cycle_index_]) {
-                // Try to create a new iterator from the next input element.
-                Status status = input_impl_->GetNext(
-                    ctx.get(), &args_list_[cycle_index_], &end_of_input_);
+          while ((!end_of_input_ || num_open_ > 0) && !busy()) {
+            if (!current_elements_[cycle_index_]) {
+              // Try to create a new iterator from the next input element.
+              Status status = input_impl_->GetNext(
+                  ctx.get(), &args_list_[cycle_index_], &end_of_input_);
+              if (!status.ok()) {
+                invocation_results_.emplace_back(new InvocationResult());
+                std::shared_ptr<InvocationResult>& result =
+                    invocation_results_.back();
+                result->status.Update(status);
+                result->notification.Notify();
+                break;
+              }
+              if (!end_of_input_) {
+                Status status = MakeIteratorFromInputElement(
+                    ctx.get(), args_list_[cycle_index_], cycle_index_,
+                    dataset()->captured_func_.get(), prefix(),
+                    &current_elements_[cycle_index_]);
                 if (!status.ok()) {
                   invocation_results_.emplace_back(new InvocationResult());
                   std::shared_ptr<InvocationResult>& result =
@@ -1486,39 +1486,25 @@
                   result->notification.Notify();
                   break;
                 }
-                if (!end_of_input_) {
-                  Status status = MakeIteratorFromInputElement(
-                      ctx.get(), args_list_[cycle_index_], cycle_index_,
-                      dataset()->captured_func_.get(), prefix(),
-                      &current_elements_[cycle_index_]);
-                  if (!status.ok()) {
-                    invocation_results_.emplace_back(new InvocationResult());
-                    std::shared_ptr<InvocationResult>& result =
-                        invocation_results_.back();
-                    result->status.Update(status);
-                    result->notification.Notify();
-                    break;
-                  }
-                  ++num_open_;
-                }
+                ++num_open_;
               }
-              if (current_elements_[cycle_index_]) {
-                // Pre-allocate invocation results for outputs to be fetched
-                // and then fetch the outputs asynchronously.
-                std::vector<std::shared_ptr<InvocationResult>> results;
-                results.reserve(dataset()->block_length_);
-                for (int i = 0; i < dataset()->block_length_; ++i) {
-                  invocation_results_.emplace_back(new InvocationResult());
-                  results.push_back(invocation_results_.back());
-                }
-                num_calls_++;
-                element_in_use_[cycle_index_] = true;
-                thread_pool_->Schedule(std::bind(&Iterator::FetchOutputs, this,
-                                                 ctx, cycle_index_,
-                                                 std::move(results)));
-              }
-              cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_;
             }
+            if (current_elements_[cycle_index_]) {
+              // Pre-allocate invocation results for outputs to be fetched
+              // and then fetch the outputs asynchronously.
+              std::vector<std::shared_ptr<InvocationResult>> results;
+              results.reserve(dataset()->block_length_);
+              for (int i = 0; i < dataset()->block_length_; ++i) {
+                invocation_results_.emplace_back(new InvocationResult());
+                results.push_back(invocation_results_.back());
+              }
+              num_calls_++;
+              element_in_use_[cycle_index_] = true;
+              thread_pool_->Schedule(std::bind(&Iterator::FetchOutputs, this,
+                                               ctx, cycle_index_,
+                                               std::move(results)));
+            }
+            cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_;
           }
           cond_var_.notify_all();
         }
@@ -1621,6 +1607,9 @@
       // and there are elements left to be fetched.
       condition_variable cond_var_;
 
+      // Identifies the maximum number of parallel calls.
+      std::atomic<int64> num_parallel_calls_;
+
       // Iterator for input elements.
       std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
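Note: in the parallel interleave iterator above, autotuned parallelism is bounded by the cycle length, and the runner stops issuing calls once cycle_length * block_length invocation results are queued. A standalone sketch of those bounds, with illustrative values:

// Sketch only: mirrors the busy() lambda used by RunnerThread above.
#include <cstdint>
#include <iostream>

bool Busy(bool element_in_use, int64_t num_calls, int64_t num_parallel_calls,
          int64_t queued_results, int64_t cycle_length, int64_t block_length) {
  return element_in_use || num_calls >= num_parallel_calls ||
         queued_results >= cycle_length * block_length;
}

int main() {
  const int64_t cycle_length = 4;  // also the autotune upper bound here
  const int64_t block_length = 2;  // at most 4 * 2 = 8 queued results
  std::cout << Busy(false, 2, 4, 8, cycle_length, block_length) << "\n";  // 1
  std::cout << Busy(true, 0, 4, 0, cycle_length, block_length) << "\n";   // 1
  std::cout << Busy(false, 2, 4, 3, cycle_length, block_length) << "\n";  // 0
  return 0;
}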
 
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
index 0795987..6abe6c8 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
@@ -44,25 +44,17 @@
  protected:
   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                    DatasetBase** output) override {
-    OpInputList inputs;
-    OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
-    std::vector<Tensor> other_arguments;
-    other_arguments.reserve(inputs.size());
-    for (const Tensor& t : inputs) {
-      other_arguments.push_back(t);
-    }
-
     int32 num_parallel_calls;
     OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
                                             &num_parallel_calls));
-    OP_REQUIRES(ctx, num_parallel_calls > 0,
+    OP_REQUIRES(ctx, num_parallel_calls > 0 || num_parallel_calls == kAutoTune,
                 errors::InvalidArgument(
                     "num_parallel_calls must be greater than zero."));
 
     std::unique_ptr<CapturedFunction> captured_func;
-    OP_REQUIRES_OK(ctx, CapturedFunction::Create(
-                            func_, std::move(other_arguments),
-                            use_inter_op_parallelism_, &captured_func));
+    OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+                                                 use_inter_op_parallelism_,
+                                                 &captured_func));
 
     *output = new Dataset(ctx, input, func_, num_parallel_calls, output_types_,
                           output_shapes_, use_inter_op_parallelism_,
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc
index 0b6e587..ee20249b 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.cc
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc
@@ -14,12 +14,14 @@
 ==============================================================================*/
 #include "tensorflow/core/kernels/data/parallel_map_iterator.h"
 
+#include <atomic>
 #include <deque>
 #include <functional>
 #include <utility>
 #include <vector>
 
 #include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/platform/cpu_info.h"
 
 namespace tensorflow {
 namespace data {
@@ -39,11 +41,6 @@
         num_parallel_calls_(num_parallel_calls) {}
 
   ~ParallelMapIterator() override {
-    // TODO(mrry): Replace this cancellation logic with a
-    // CancellationManager. The syntax would be more heavyweight,
-    // but it would be possible to thread a cancellation manager
-    // through the IteratorContext to upstream,
-    // potentially-blocking iterators, when we add these.
     mutex_lock l(mu_);
     // Cancel the runner thread.
     cancelled_ = true;
@@ -55,7 +52,17 @@
   }
 
   Status Initialize(IteratorContext* ctx) override {
-    SetMetadata(ctx, "parallelism", num_parallel_calls_);
+    mutex_lock l(mu_);
+    if (num_parallel_calls_ == kAutoTune) {
+      num_parallel_calls_ = 1;
+      // TODO(jsimsa): Surface the number of threads used by `ctx->runner()` and
+      // use it here for the maximum.
+      AddTunableParameter(ctx, "parallelism", &num_parallel_calls_ /* value */,
+                          1 /* min */, port::NumSchedulableCPUs() /* max */,
+                          &cond_var_);
+    } else {
+      AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
+    }
     TF_RETURN_IF_ERROR(
         input_dataset_->MakeIterator(ctx, prefix(), &input_impl_));
     if (init_func_) {
@@ -71,17 +78,17 @@
       mutex_lock l(mu_);
       EnsureRunnerThreadStarted(ctx);
       while (invocation_results_.empty()) {
-        StopWork(ctx);
+        RecordStop(ctx);
         cond_var_.wait(l);
-        StartWork(ctx);
+        RecordStart(ctx);
       }
       std::swap(result, invocation_results_.front());
       invocation_results_.pop_front();
+      cond_var_.notify_all();
     }
-    cond_var_.notify_all();
-    StopWork(ctx);
+    RecordStop(ctx);
     result->notification.WaitForNotification();
-    StartWork(ctx);
+    RecordStart(ctx);
     return ProcessResult(result, out_tensors, end_of_sequence);
   }
 
@@ -182,9 +189,9 @@
     {
       mutex_lock l(mu_);
       num_calls_--;
+      cond_var_.notify_all();
     }
     result->notification.Notify();
-    cond_var_.notify_all();
   }
 
   void CallFunction(const std::shared_ptr<IteratorContext>& ctx,
@@ -199,9 +206,8 @@
       return;
     }
 
-    // Call `func_(input_element)`, store the result in
-    // `result->return_values`, and notify `result->notification` to unblock
-    // a consumer.
+    // Call `func_(input_element)`, store the result in `result->return_values`,
+    // and notify `result->notification` to unblock a consumer.
     auto done = [this, result](Status status) {
       result->status.Update(status);
       CallCompleted(result);
@@ -211,8 +217,6 @@
               std::move(done));
   }
 
-  int64 MaxInvocationResults() { return num_parallel_calls_; }
-
   Status ProcessResult(const std::shared_ptr<InvocationResult>& result,
                        std::vector<Tensor>* out_tensors,
                        bool* end_of_sequence) {
@@ -232,31 +236,33 @@
   }
 
   void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
-    StartWork(ctx.get());
-    auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); });
+    RecordStart(ctx.get());
+    auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
     std::vector<std::shared_ptr<InvocationResult>> new_calls;
     new_calls.reserve(num_parallel_calls_);
+    auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool {
+      int64 num_parallel_calls = num_parallel_calls_;
+      return num_calls_ >= num_parallel_calls ||
+             invocation_results_.size() >= num_parallel_calls;
+    };
     while (true) {
       {
         mutex_lock l(mu_);
-        while (!cancelled_ &&
-               (num_calls_ >= num_parallel_calls_ ||
-                invocation_results_.size() >= MaxInvocationResults())) {
-          StopWork(ctx.get());
+        while (!cancelled_ && busy()) {
+          RecordStop(ctx.get());
           cond_var_.wait(l);
-          StartWork(ctx.get());
+          RecordStart(ctx.get());
         }
         if (cancelled_) {
           return;
         }
-        while (num_calls_ < num_parallel_calls_ &&
-               invocation_results_.size() < MaxInvocationResults()) {
+        while (!busy()) {
           invocation_results_.emplace_back(new InvocationResult());
           new_calls.push_back(invocation_results_.back());
           num_calls_++;
         }
+        cond_var_.notify_all();
       }
-      cond_var_.notify_all();
       for (const auto& call : new_calls) {
         CallFunction(ctx, call);
       }
@@ -305,7 +311,6 @@
   const DatasetBase* const input_dataset_;  // Not owned.
   const std::function<Status(IteratorContext*)> init_func_;
   const ParallelMapIteratorFunction map_func_;
-  const int32 num_parallel_calls_;
   // Used for coordination between the main thread and the runner thread.
   mutex mu_;
   // Used for coordination between the main thread and the runner thread. In
@@ -314,6 +319,8 @@
   // parallelism and there are slots available in the `invocation_results_`
   // buffer.
   condition_variable cond_var_;
+  // Identifies the maximum number of parallel calls.
+  std::atomic<int64> num_parallel_calls_;
   // Counts the number of outstanding calls.
   int64 num_calls_ GUARDED_BY(mu_) = 0;
   std::unique_ptr<IteratorBase> input_impl_;
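Note: several hunks in this file (and in map_and_batch above) move `cond_var_.notify_all()` inside the critical section rather than after releasing `mu_`. A minimal standalone sketch of that producer/consumer shape, using std:: primitives rather than the TensorFlow wrappers:

// Sketch only: the producer signals while still holding the lock, matching
// the placement of notify_all() in the hunks above.
#include <condition_variable>
#include <deque>
#include <iostream>
#include <mutex>
#include <thread>

int main() {
  std::mutex mu;
  std::condition_variable cond_var;
  std::deque<int> results;
  bool done = false;

  std::thread producer([&] {
    for (int i = 0; i < 5; ++i) {
      std::lock_guard<std::mutex> l(mu);
      results.push_back(i);
      cond_var.notify_all();  // notify while holding the lock
    }
    std::lock_guard<std::mutex> l(mu);
    done = true;
    cond_var.notify_all();
  });

  int consumed = 0;
  while (true) {
    std::unique_lock<std::mutex> l(mu);
    cond_var.wait(l, [&] { return !results.empty() || done; });
    if (!results.empty()) {
      results.pop_front();
      ++consumed;
    } else if (done) {
      break;
    }
  }
  producer.join();
  std::cout << consumed << " results consumed\n";
  return 0;
}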
diff --git a/tensorflow/core/kernels/data/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/parse_example_dataset_op.cc
index 0cf5db0..c28c06d 100644
--- a/tensorflow/core/kernels/data/parse_example_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parse_example_dataset_op.cc
@@ -87,11 +87,8 @@
                     "Expected len(dense_defaults) == len(dense_keys) but got: ",
                     dense_default_tensors.size(), " vs. ", dense_keys_.size()));
 
-    std::vector<Tensor> dense_defaults;
-    dense_defaults.reserve(dense_default_tensors.size());
-    for (const Tensor& dense_default_t : dense_default_tensors) {
-      dense_defaults.push_back(dense_default_t);
-    }
+    std::vector<Tensor> dense_defaults(dense_default_tensors.begin(),
+                                       dense_default_tensors.end());
 
     for (int d = 0; d < dense_keys_.size(); ++d) {
       const Tensor& def_value = dense_defaults[d];
diff --git a/tensorflow/core/kernels/data/prefetch_autotuner.cc b/tensorflow/core/kernels/data/prefetch_autotuner.cc
index 533d0bd..da35733 100644
--- a/tensorflow/core/kernels/data/prefetch_autotuner.cc
+++ b/tensorflow/core/kernels/data/prefetch_autotuner.cc
@@ -26,6 +26,13 @@
   }
 }
 
+namespace {
+// Determines what strategy to use for increasing the buffer size limit. For
+// limits less than the threshold, an exponential increase is used, while for
+// limits greater than or equal to the threshold, a linear increase is used.
+constexpr size_t kBufferLimitThreshold = 2048;
+}  // namespace
+
 void PrefetchAutotuner::RecordConsumption(size_t current_buffer_size) {
   switch (mode_) {
     case Mode::kDisabled:
@@ -37,7 +44,11 @@
       return;
     case Mode::kDownswing:
       if (current_buffer_size == 0) {
-        buffer_limit_ *= 2;  // Increase the buffer size.
+        if (buffer_limit_ >= kBufferLimitThreshold) {
+          buffer_limit_ += kBufferLimitThreshold;
+        } else {
+          buffer_limit_ *= 2;
+        }
         mode_ = Mode::kUpswing;
       }
       return;
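Note: with the new threshold, the autotuner doubles the buffer limit while it is below 2048 and grows it by 2048 afterwards. A standalone sketch of the resulting schedule; the starting limit of 1 is illustrative:

// Sketch only: applies the exponential-then-linear growth rule from the
// hunk above and prints the successive limits.
#include <cstddef>
#include <iostream>

int main() {
  const size_t kThreshold = 2048;
  size_t limit = 1;
  for (int step = 0; step < 16; ++step) {
    std::cout << limit << " ";
    if (limit >= kThreshold) {
      limit += kThreshold;
    } else {
      limit *= 2;
    }
  }
  std::cout << "\n";  // 1 2 4 ... 1024 2048 4096 6144 8192 ...
  return 0;
}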
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
index 52c421c..754ed77 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
@@ -103,18 +103,18 @@
     Status GetNextInternal(IteratorContext* ctx,
                            std::vector<Tensor>* out_tensors,
                            bool* end_of_sequence) override {
+      auto stats_aggregator = ctx->stats_aggregator();
       {
         mutex_lock l(mu_);
-        auto stats_aggregator = ctx->stats_aggregator();
         TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx));
         // Wait until the next element in the buffer has been
         // produced, or we are shutting down.
         while (!cancelled_ && buffer_.empty() && !prefetch_thread_finished_ &&
                auto_tuner_.buffer_limit() != 0) {
           auto_tuner_.RecordEmpty();
-          StopWork(ctx);
+          RecordStop(ctx);
           cond_var_.wait(l);
-          StartWork(ctx);
+          RecordStart(ctx);
         }
 
         if (cancelled_) {
@@ -136,6 +136,14 @@
 
       mutex_lock parent_l(parent_mu_);
       mutex_lock l(mu_);
+      if (stats_aggregator) {
+        stats_aggregator->AddScalar(
+            strings::StrCat(prefix_end_, "::buffer_size"),
+            static_cast<float>(buffer_.size()));
+        stats_aggregator->AddScalar(
+            strings::StrCat(prefix_end_, "::buffer_capacity"),
+            static_cast<float>(auto_tuner_.buffer_limit()));
+      }
       return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
     }
 
@@ -219,6 +227,12 @@
             strings::StrCat(prefix_end_, "::buffer_utilization"),
             {static_cast<float>(buffer_.size()) /
              static_cast<float>(auto_tuner_.buffer_limit())});
+        stats_aggregator->AddScalar(
+            strings::StrCat(prefix_end_, "::buffer_size"),
+            static_cast<float>(buffer_.size()));
+        stats_aggregator->AddScalar(
+            strings::StrCat(prefix_end_, "::buffer_capacity"),
+            static_cast<float>(auto_tuner_.buffer_limit()));
       }
       // A new element is available. Forward the status from computing it, and
       // (if we successfully got an element) the output values.
@@ -255,8 +269,8 @@
     //
     // It owns the iterator context passed to it.
     void PrefetchThread(const std::shared_ptr<IteratorContext>& ctx) {
-      StartWork(ctx.get());
-      auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); });
+      RecordStart(ctx.get());
+      auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
       while (true) {
         std::vector<Tensor> value;
 
@@ -264,9 +278,9 @@
         {
           mutex_lock l(mu_);
           while (!cancelled_ && buffer_.size() >= auto_tuner_.buffer_limit()) {
-            StopWork(ctx.get());
+            RecordStop(ctx.get());
             cond_var_.wait(l);
-            StartWork(ctx.get());
+            RecordStart(ctx.get());
           }
 
           if (cancelled_) {
diff --git a/tensorflow/core/kernels/data/scan_dataset_op.cc b/tensorflow/core/kernels/data/scan_dataset_op.cc
index 6e515d6..dbe31f3 100644
--- a/tensorflow/core/kernels/data/scan_dataset_op.cc
+++ b/tensorflow/core/kernels/data/scan_dataset_op.cc
@@ -45,23 +45,12 @@
     OpInputList initial_state_inputs;
     OP_REQUIRES_OK(ctx,
                    ctx->input_list("initial_state", &initial_state_inputs));
-    std::vector<Tensor> initial_state;
-    initial_state.reserve(initial_state_inputs.size());
-    for (const Tensor& t : initial_state_inputs) {
-      initial_state.push_back(t);
-    }
-
-    OpInputList inputs;
-    OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
-    std::vector<Tensor> other_arguments;
-    other_arguments.reserve(inputs.size());
-    for (const Tensor& t : inputs) {
-      other_arguments.push_back(t);
-    }
+    std::vector<Tensor> initial_state(initial_state_inputs.begin(),
+                                      initial_state_inputs.end());
 
     std::unique_ptr<CapturedFunction> captured_func;
-    OP_REQUIRES_OK(ctx, CapturedFunction::Create(
-                            func_, std::move(other_arguments), &captured_func));
+    OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+                                                 &captured_func));
 
     *output = new Dataset(ctx, input, func_, std::move(initial_state),
                           std::move(captured_func), state_types_, output_types_,
diff --git a/tensorflow/core/kernels/data/tensor_dataset_op.cc b/tensorflow/core/kernels/data/tensor_dataset_op.cc
index e1cefd2..ca4ea25 100644
--- a/tensorflow/core/kernels/data/tensor_dataset_op.cc
+++ b/tensorflow/core/kernels/data/tensor_dataset_op.cc
@@ -33,11 +33,7 @@
     OP_REQUIRES_OK(ctx, ctx->input_list("components", &inputs));
     // TODO(mrry): Validate that the shapes of the "components" tensors match
     // the "shapes" attr.;
-    std::vector<Tensor> components;
-    components.reserve(inputs.size());
-    for (const Tensor& t : inputs) {
-      components.push_back(t);
-    }
+    std::vector<Tensor> components(inputs.begin(), inputs.end());
     *output = new Dataset(ctx, std::move(components));
   }
 
diff --git a/tensorflow/core/kernels/data/window_dataset_op.cc b/tensorflow/core/kernels/data/window_dataset_op.cc
index 3975086..ac44623 100644
--- a/tensorflow/core/kernels/data/window_dataset_op.cc
+++ b/tensorflow/core/kernels/data/window_dataset_op.cc
@@ -33,22 +33,44 @@
   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                    DatasetBase** output) override {
     int64 window_size = 0;
-    OP_REQUIRES_OK(
-        ctx, ParseScalarArgument<int64>(ctx, "window_size", &window_size));
+    OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "size", &window_size));
     OP_REQUIRES(
         ctx, window_size > 0,
         errors::InvalidArgument("Window size must be greater than zero."));
 
-    *output = new Dataset(ctx, window_size, input);
+    int64 window_shift = 0;
+    OP_REQUIRES_OK(ctx,
+                   ParseScalarArgument<int64>(ctx, "shift", &window_shift));
+    OP_REQUIRES(
+        ctx, window_shift > 0,
+        errors::InvalidArgument("Window shift must be greater than zero."));
+
+    int64 window_stride = 0;
+    OP_REQUIRES_OK(ctx,
+                   ParseScalarArgument<int64>(ctx, "stride", &window_stride));
+    OP_REQUIRES(
+        ctx, window_stride > 0,
+        errors::InvalidArgument("Window stride must be greater than zero."));
+
+    bool drop_remainder;
+    OP_REQUIRES_OK(
+        ctx, ParseScalarArgument<bool>(ctx, "drop_remainder", &drop_remainder));
+
+    *output = new Dataset(ctx, input, window_size, window_shift, window_stride,
+                          drop_remainder);
   }
 
  private:
   class Dataset : public DatasetBase {
    public:
-    Dataset(OpKernelContext* ctx, int64 window_size, const DatasetBase* input)
+    Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 window_size,
+            int64 window_shift, int64 window_stride, bool drop_remainder)
         : DatasetBase(DatasetContext(ctx)),
+          input_(input),
           window_size_(window_size),
-          input_(input) {
+          window_shift_(window_shift),
+          window_stride_(window_stride),
+          drop_remainder_(drop_remainder) {
       input_->Ref();
     }
 
@@ -72,7 +94,8 @@
     }
 
     string DebugString() const override {
-      return strings::StrCat("WindowDatasetOp(", window_size_, ")::Dataset");
+      return strings::StrCat("WindowDatasetOp(", window_size_, ", ",
+                             window_shift_, ", ", window_stride_, ", ",
+                             drop_remainder_, ")::Dataset");
     }
 
    protected:
@@ -81,10 +104,19 @@
                               Node** output) const override {
       Node* input_graph_node = nullptr;
       TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
-      Node* window_size = nullptr;
-      TF_RETURN_IF_ERROR(b->AddScalar(window_size_, &window_size));
+      Node* window_size_node = nullptr;
+      TF_RETURN_IF_ERROR(b->AddScalar(window_size_, &window_size_node));
+      Node* window_shift_node = nullptr;
+      TF_RETURN_IF_ERROR(b->AddScalar(window_shift_, &window_shift_node));
+      Node* window_stride_node = nullptr;
+      TF_RETURN_IF_ERROR(b->AddScalar(window_stride_, &window_stride_node));
+      Node* drop_remainder_node = nullptr;
+      TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder_node));
       TF_RETURN_IF_ERROR(
-          b->AddDataset(this, {input_graph_node, window_size}, output));
+          b->AddDataset(this,
+                        {input_graph_node, window_size_node, window_shift_node,
+                         window_stride_node, drop_remainder_node},
+                        output));
       return Status::OK();
     }
 
@@ -101,37 +133,79 @@
       Status GetNextInternal(IteratorContext* ctx,
                              std::vector<Tensor>* out_tensors,
                              bool* end_of_sequence) override {
-        // Each row of `window_elements` is a tuple of tensors from the
-        // input iterator.
+        const int64 window_size = dataset()->window_size_;
+        const int64 window_shift = dataset()->window_shift_;
+        const int64 window_stride = dataset()->window_stride_;
         std::vector<std::vector<Tensor>> window_elements;
+        Status status = Status::OK();
         {
           mutex_lock l(mu_);
-          if (!input_impl_) {
+          if (!input_impl_ && buffer_.empty()) {
             *end_of_sequence = true;
             return Status::OK();
           }
-          window_elements.reserve(dataset()->window_size_);
-          *end_of_sequence = false;
-          for (int i = 0; i < dataset()->window_size_ && !*end_of_sequence;
-               ++i) {
-            std::vector<Tensor> window_element_tuple;
-            TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &window_element_tuple,
-                                                    end_of_sequence));
-            if (!*end_of_sequence) {
-              window_elements.emplace_back(std::move(window_element_tuple));
-            } else {
-              input_impl_.reset();
+
+          // Add elements to the buffer.
+          size_t target_size = TargetBufferSize(window_size, window_stride);
+          if (input_impl_) {
+            *end_of_sequence = false;
+            for (size_t i = buffer_.size();
+                 i < target_size && !*end_of_sequence; ++i) {
+              std::vector<Tensor> element;
+              Status status =
+                  input_impl_->GetNext(ctx, &element, end_of_sequence);
+              if (!*end_of_sequence) {
+                buffer_.emplace_back(std::move(element), status);
+              } else {
+                input_impl_.reset();
+              }
             }
           }
+
+          // If there are not enough elements and `drop_remainder` is set, we do
+          // not wish to return a smaller window.
+          if (buffer_.empty() ||
+              (dataset()->drop_remainder_ && buffer_.size() < target_size)) {
+            DCHECK(*end_of_sequence);
+            return Status::OK();
+          }
+
+          int num_elements = 1 + (buffer_.size() - 1) / window_stride;
+          window_elements.reserve(num_elements);
+          for (size_t i = 0; i < num_elements; ++i) {
+            status.Update(buffer_[window_stride * i].status);
+            if (!status.ok()) {
+              break;
+            }
+            window_elements.emplace_back(buffer_[window_stride * i].result);
+          }
+
+          // Shift the window, discarding elements if necessary.
+          int buffer_size = buffer_.size();
+          if (window_shift >= buffer_size) {
+            for (size_t i = buffer_size; input_impl_ && i < window_shift; ++i) {
+              bool end_of_input;
+              std::vector<Tensor> element;
+              // Ignore non-error status of discarded elements.
+              input_impl_->GetNext(ctx, &element, &end_of_input).IgnoreError();
+              if (end_of_input) {
+                input_impl_.reset();
+              }
+            }
+            buffer_.clear();
+          } else {
+            buffer_.erase(buffer_.begin(), buffer_.begin() + window_shift);
+          }
         }
 
-        if (window_elements.empty()) {
-          DCHECK(*end_of_sequence);
-          return Status::OK();
+        if (!status.ok()) {
+          return status;
         }
 
+        // Construct output tensors.
         const size_t num_tuple_components = window_elements[0].size();
         const int64 num_window_elements = window_elements.size();
+        *end_of_sequence = false;
         for (size_t idx = 0; idx < num_tuple_components; ++idx) {
           DatasetBase* window_dataset;
           std::vector<std::vector<Tensor>> window_component_elements;
@@ -154,7 +228,6 @@
           TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor(window_dataset,
                                                          &out_tensors->back()));
         }
-        *end_of_sequence = false;
         return Status::OK();
       }
 
@@ -167,6 +240,20 @@
         } else {
           TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
         }
+        // Save buffer.
+        TF_RETURN_IF_ERROR(writer->WriteScalar(strings::StrCat("buffer_size"),
+                                               buffer_.size()));
+        for (int64 i = 0; i < buffer_.size(); i++) {
+          TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, buffer_[i].status));
+          TF_RETURN_IF_ERROR(
+              writer->WriteScalar(strings::StrCat("buffer[", i, "].size"),
+                                  buffer_[i].result.size()));
+          for (int64 j = 0; j < buffer_[i].result.size(); j++) {
+            TF_RETURN_IF_ERROR(
+                writer->WriteTensor(strings::StrCat("buffer[", i, "][", j, "]"),
+                                    buffer_[i].result[j]));
+          }
+        }
         return Status::OK();
       }
 
@@ -178,22 +265,92 @@
         } else {
           input_impl_.reset();
         }
+        // Restore buffer.
+        int64 buffer_size;
+        TF_RETURN_IF_ERROR(
+            reader->ReadScalar(strings::StrCat("buffer_size"), &buffer_size));
+        buffer_.resize(buffer_size);
+        for (int64 i = 0; i < buffer_size; i++) {
+          int64 vector_size;
+          TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &buffer_[i].status));
+          TF_RETURN_IF_ERROR(reader->ReadScalar(
+              strings::StrCat("buffer[", i, "].size"), &vector_size));
+          buffer_[i].result.resize(vector_size);
+          for (int64 j = 0; j < vector_size; j++) {
+            TF_RETURN_IF_ERROR(
+                reader->ReadTensor(strings::StrCat("buffer[", i, "][", j, "]"),
+                                   &buffer_[i].result[j]));
+          }
+        }
         return Status::OK();
       }
 
      private:
+      struct InvocationResult {
+        InvocationResult() = default;
+        InvocationResult(std::vector<Tensor>&& result, const Status& status)
+            : result(result), status(status) {}
+
+        std::vector<Tensor> result;
+        Status status;
+      };
+
+      Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
+                               const Status& status)
+          EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+        TF_RETURN_IF_ERROR(writer->WriteScalar(
+            CodeKey(index), static_cast<int64>(status.code())));
+        if (!status.ok()) {
+          TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index),
+                                                 status.error_message()));
+        }
+        return Status::OK();
+      }
+
+      Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
+                              Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+        int64 code_int;
+        TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
+        error::Code code = static_cast<error::Code>(code_int);
+
+        if (code != error::Code::OK) {
+          string error_message;
+          TF_RETURN_IF_ERROR(
+              reader->ReadScalar(ErrorMessageKey(index), &error_message));
+          *status = Status(code, error_message);
+        } else {
+          *status = Status::OK();
+        }
+        return Status::OK();
+      }
+
+      string CodeKey(size_t index) {
+        return full_name(strings::StrCat("buffer[", index, "].code"));
+      }
+
+      string ErrorMessageKey(size_t index) {
+        return full_name(strings::StrCat("buffer[", index, "].error_message"));
+      }
+
+      size_t TargetBufferSize(int64 window_size, int64 window_stride) {
+        return (window_size - 1) * window_stride + 1;
+      }
+
       mutex mu_;
+      std::deque<InvocationResult> buffer_ GUARDED_BY(mu_);
       std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
     };
 
-    const int64 window_size_;
     const DatasetBase* const input_;
+    const int64 window_size_;
+    const int64 window_shift_;
+    const int64 window_stride_;
+    const bool drop_remainder_;
   };
 };
 
 REGISTER_KERNEL_BUILDER(Name("WindowDataset").Device(DEVICE_CPU),
                         WindowDatasetOp);
-
 }  // namespace
 }  // namespace data
 }  // namespace tensorflow
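
With the added attributes, each emitted window contains the buffered elements at offsets 0, stride, 2*stride, ..., the buffer target is (size - 1) * stride + 1 elements, `shift` controls how far the window start advances between calls, and `drop_remainder` suppresses trailing windows that cannot be filled. A standalone sketch of that index arithmetic (it mirrors, but is not, the kernel above), assuming an input of n consecutive integers:

    #include <cstdio>
    #include <vector>

    int main() {
      const int size = 3, shift = 2, stride = 2, n = 10;
      const bool drop_remainder = false;
      const int target = (size - 1) * stride + 1;  // elements buffered per window
      std::printf("buffer target per window: %d\n", target);
      for (int start = 0; start < n; start += shift) {
        std::vector<int> window;
        for (int i = 0; i < size && start + i * stride < n; ++i) {
          window.push_back(start + i * stride);
        }
        if (window.empty() ||
            (drop_remainder && static_cast<int>(window.size()) < size)) {
          break;
        }
        std::printf("window starting at %d:", start);
        for (int idx : window) std::printf(" %d", idx);
        std::printf("\n");
      }
      return 0;
    }

With size=3, shift=2, stride=2 over ten elements this prints {0,2,4}, {2,4,6}, {4,6,8}, {6,8}, {8}; setting drop_remainder drops the last two partial windows.
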
diff --git a/tensorflow/core/kernels/decode_bmp_op.cc b/tensorflow/core/kernels/decode_bmp_op.cc
index 750efca..ae451be 100644
--- a/tensorflow/core/kernels/decode_bmp_op.cc
+++ b/tensorflow/core/kernels/decode_bmp_op.cc
@@ -91,8 +91,10 @@
                 errors::InvalidArgument(
                     "Number of channels must be 1, 3 or 4, was ", channels_));
 
-    OP_REQUIRES(context, width > 0 && header_size >= 0,
+    OP_REQUIRES(context, width > 0,
                 errors::InvalidArgument("Width must be positive"));
+    OP_REQUIRES(context, height != 0,
+                errors::InvalidArgument("Height must be nonzero"));
     OP_REQUIRES(context, header_size >= 0,
                 errors::InvalidArgument("header size must be nonnegative"));
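
The separate `height != 0` check above reflects that the BMP format allows a negative height, meaning rows are stored top-down, so only a zero height is invalid. A small illustrative sketch of how a decoder would typically derive the row count and orientation from a parsed header height:

    #include <cstdio>
    #include <cstdlib>

    int main() {
      // BMP convention: height < 0 means top-down row order; the number of rows
      // is the absolute value either way. (Illustrative only.)
      const int height = -480;
      const bool top_down = height < 0;
      const int num_rows = std::abs(height);
      std::printf("rows=%d top_down=%s\n", num_rows, top_down ? "true" : "false");
      return 0;
    }
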
 
diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
index 2a25459..76afd6f 100644
--- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
@@ -17,7 +17,7 @@
 #define EIGEN_USE_GPU
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "external/cub_archive/cub/util_ptx.cuh"
+#include "third_party/cub/util_ptx.cuh"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/kernels/depthwise_conv_op.h"
 #include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc
index 862a977..e7882ac 100644
--- a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc
@@ -35,10 +35,10 @@
 
 #define EIGEN_USE_GPU
 
-#include "external/cub_archive/cub/device/device_radix_sort.cuh"
-#include "external/cub_archive/cub/device/device_reduce.cuh"
-#include "external/cub_archive/cub/iterator/constant_input_iterator.cuh"
-#include "external/cub_archive/cub/thread/thread_operators.cuh"
+#include "third_party/cub/device/device_radix_sort.cuh"
+#include "third_party/cub/device/device_reduce.cuh"
+#include "third_party/cub/iterator/constant_input_iterator.cuh"
+#include "third_party/cub/thread/thread_operators.cuh"
 #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
diff --git a/tensorflow/core/kernels/eigen_volume_patch.h b/tensorflow/core/kernels/eigen_volume_patch.h
index a3d7958..80ab745 100644
--- a/tensorflow/core/kernels/eigen_volume_patch.h
+++ b/tensorflow/core/kernels/eigen_volume_patch.h
@@ -43,6 +43,7 @@
     IsAligned = false,
     PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
     BlockAccess = false,
+    PreferBlockAccess = false,
     Layout = TensorEvaluator<ArgType, Device>::Layout,
     CoordAccess = NumDims == 6,
     RawAccess = false
diff --git a/tensorflow/core/kernels/fuzzing/BUILD b/tensorflow/core/kernels/fuzzing/BUILD
index 8bfa403..f2e0b25 100644
--- a/tensorflow/core/kernels/fuzzing/BUILD
+++ b/tensorflow/core/kernels/fuzzing/BUILD
@@ -43,4 +43,6 @@
 
 tf_ops_fuzz_target_lib("parse_tensor_op")
 
+tf_ops_fuzz_target_lib("decode_compressed")
+
 tf_ops_fuzz_target_lib("decode_json_example")
diff --git a/tensorflow/core/kernels/fuzzing/decode_compressed_fuzz.cc b/tensorflow/core/kernels/fuzzing/decode_compressed_fuzz.cc
new file mode 100644
index 0000000..0a56f4b
--- /dev/null
+++ b/tensorflow/core/kernels/fuzzing/decode_compressed_fuzz.cc
@@ -0,0 +1,45 @@
+/* Copyright 2018 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
+
+namespace tensorflow {
+namespace fuzzing {
+
+class FuzzDecodeCompressed : public FuzzStringInputOp {
+  void BuildGraph(const Scope& scope) override {
+    auto input =
+        tensorflow::ops::Placeholder(scope.WithOpName("input1"), DT_STRING);
+    auto d1 = tensorflow::ops::DecodeCompressed(
+        scope.WithOpName("d1"), input,
+        tensorflow::ops::DecodeCompressed::CompressionType(""));
+    auto d2 = tensorflow::ops::DecodeCompressed(
+        scope.WithOpName("d2"), input,
+        tensorflow::ops::DecodeCompressed::CompressionType("ZLIB"));
+    auto d3 = tensorflow::ops::DecodeCompressed(
+        scope.WithOpName("d3"), input,
+        tensorflow::ops::DecodeCompressed::CompressionType("GZIP"));
+    Scope grouper =
+        scope.WithControlDependencies(std::vector<tensorflow::Operation>{
+            d1.output.op(), d2.output.op(), d3.output.op()});
+    (void)tensorflow::ops::NoOp(grouper.WithOpName("output"));
+  }
+};
+
+STANDARD_TF_FUZZ_FUNCTION(FuzzDecodeCompressed);
+
+}  // namespace fuzzing
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/histogram_op_gpu.cu.cc b/tensorflow/core/kernels/histogram_op_gpu.cu.cc
index a88e9b0..374a058 100644
--- a/tensorflow/core/kernels/histogram_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/histogram_op_gpu.cu.cc
@@ -18,7 +18,7 @@
 #define EIGEN_USE_GPU
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "external/cub_archive/cub/device/device_histogram.cuh"
+#include "third_party/cub/device/device_histogram.cuh"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/core/kernels/logging_ops.cc b/tensorflow/core/kernels/logging_ops.cc
index 6b6a14e..1ded012 100644
--- a/tensorflow/core/kernels/logging_ops.cc
+++ b/tensorflow/core/kernels/logging_ops.cc
@@ -14,6 +14,7 @@
 ==============================================================================*/
 
 #include <iostream>
+#include "absl/strings/str_split.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/strings/str_util.h"
@@ -90,6 +91,59 @@
 
 REGISTER_KERNEL_BUILDER(Name("Print").Device(DEVICE_CPU), PrintOp);
 
+class PrintV2Op : public OpKernel {
+ public:
+  explicit PrintV2Op(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_stream", &output_stream_));
+
+    auto output_stream_index =
+        std::find(std::begin(valid_output_streams_),
+                  std::end(valid_output_streams_), output_stream_);
+
+    if (output_stream_index == std::end(valid_output_streams_)) {
+      string error_msg = strings::StrCat(
+          "Unknown output stream: ", output_stream_, ". Valid streams are:");
+      for (auto valid_stream : valid_output_streams_) {
+        strings::StrAppend(&error_msg, " ", valid_stream);
+      }
+      OP_REQUIRES(ctx, false, errors::InvalidArgument(error_msg));
+    }
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor* input_;
+    OP_REQUIRES_OK(ctx, ctx->input("input", &input_));
+    const string& msg = input_->scalar<string>()();
+
+    if (output_stream_ == "stdout") {
+      std::cout << msg << std::endl;
+    } else if (output_stream_ == "stderr") {
+      std::cerr << msg << std::endl;
+    } else if (output_stream_ == "log(info)") {
+      LOG(INFO) << msg << std::endl;
+    } else if (output_stream_ == "log(warning)") {
+      LOG(WARNING) << msg << std::endl;
+    } else if (output_stream_ == "log(error)") {
+      LOG(ERROR) << msg << std::endl;
+    } else {
+      string error_msg = strings::StrCat(
+          "Unknown output stream: ", output_stream_, ". Valid streams are:");
+      for (auto valid_stream : valid_output_streams_) {
+        strings::StrAppend(&error_msg, " ", valid_stream);
+      }
+      OP_REQUIRES(ctx, false, errors::InvalidArgument(error_msg));
+    }
+  }
+
+  const char* valid_output_streams_[] = {"stdout", "stderr", "log(info)",
+                                         "log(warning)", "log(error)"};
+
+ private:
+  string output_stream_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("PrintV2").Device(DEVICE_CPU), PrintV2Op);
+
 class TimestampOp : public OpKernel {
  public:
   explicit TimestampOp(OpKernelConstruction* context) : OpKernel(context) {}
diff --git a/tensorflow/core/kernels/logging_ops_test.cc b/tensorflow/core/kernels/logging_ops_test.cc
index 5e6958f..a259d99 100644
--- a/tensorflow/core/kernels/logging_ops_test.cc
+++ b/tensorflow/core/kernels/logging_ops_test.cc
@@ -23,11 +23,33 @@
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/kernels/ops_testutil.h"
 #include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 
 namespace tensorflow {
 namespace {
 
+class PrintingV2GraphTest : public OpsTestBase {
+ protected:
+  Status Init(const string& output_stream = "log(warning)") {
+    TF_CHECK_OK(NodeDefBuilder("op", "PrintV2")
+                    .Input(FakeInput(DT_STRING))
+                    .Attr("output_stream", output_stream)
+                    .Finalize(node_def()));
+    return InitOp();
+  }
+};
+
+TEST_F(PrintingV2GraphTest, StringSuccess) {
+  TF_ASSERT_OK(Init());
+  AddInputFromArray<string>(TensorShape({}), {"bar"});
+  TF_ASSERT_OK(RunOpKernel());
+}
+
+TEST_F(PrintingV2GraphTest, InvalidOutputStream) {
+  ASSERT_NE(::tensorflow::Status::OK(), (Init("invalid_output_stream")));
+}
+
 class PrintingGraphTest : public OpsTestBase {
  protected:
   Status Init(DataType input_type1, DataType input_type2, string msg = "",
diff --git a/tensorflow/core/kernels/mirror_pad_op.h b/tensorflow/core/kernels/mirror_pad_op.h
index cc4b694..62aa7d5 100644
--- a/tensorflow/core/kernels/mirror_pad_op.h
+++ b/tensorflow/core/kernels/mirror_pad_op.h
@@ -103,6 +103,7 @@
     IsAligned = false,
     PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
     BlockAccess = false,
+    PreferBlockAccess = false,
     Layout = TensorEvaluator<ArgType, Device>::Layout,
     CoordAccess = true,
     RawAccess = false
diff --git a/tensorflow/core/kernels/mkl_conv_ops_test.cc b/tensorflow/core/kernels/mkl_conv_ops_test.cc
new file mode 100644
index 0000000..a055351
--- /dev/null
+++ b/tensorflow/core/kernels/mkl_conv_ops_test.cc
@@ -0,0 +1,407 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/nn_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/session.h"
+
+#if defined(INTEL_MKL_DNN_ONLY)
+#include "third_party/intel_mkl_dnn/include/mkldnn.h"
+#include "tensorflow/core/util/mkl_util.h"
+#endif
+
+// TODO(ezhulenev): Add numerical tests that will compare results of default
+// (aka Eigen) convolutions with MKL convolutions.
+
+// -------------------------------------------------------------------------- //
+// Performance Benchmarks.                                                    //
+// -------------------------------------------------------------------------- //
+
+// Compare performance of default Tensorflow convolution kernels (Eigen) with
+// MKL kernels on CPU.
+
+// Before running these benchmarks configure OpenMP environment variables:
+//   export KMP_BLOCKTIME=0
+//   export OMP_NUM_THREADS=${num_threads}
+
+namespace tensorflow {
+
+struct Conv2DDimensions {
+  Conv2DDimensions(int n, int h, int w, int c, int fc, int fh, int fw)
+      : input_batches(n),
+        input_height(h),
+        input_width(w),
+        input_depth(c),
+        filter_count(fc),
+        filter_height(fh),
+        filter_width(fw) {}
+
+  int input_batches;
+  int input_height;
+  int input_width;
+  int input_depth;
+  int filter_count;
+  int filter_height;
+  int filter_width;
+};
+
+static Tensor GetRandomTensor(const TensorShape& shape) {
+  Tensor tensor(DT_FLOAT, TensorShape(shape));
+  tensor.flat<float>() = tensor.flat<float>().setRandom();
+  return tensor;
+}
+
+// Get a random Tensor for the Conv2D input.
+static Tensor GetRandomInputTensor(const Conv2DDimensions& dims) {
+  return GetRandomTensor({dims.input_batches, dims.input_height,
+                          dims.input_width, dims.input_depth});
+}
+
+// Get a random Tensor for the Conv2D filter.
+static Tensor GetRandomFilterTensor(const Conv2DDimensions& dims) {
+  return GetRandomTensor({dims.filter_height, dims.filter_width,
+                          dims.input_depth, dims.filter_count});
+}
+
+// Get a random Tensor for the Conv2D output (assuming SAME padding).
+static Tensor GetRandomOutputTensor(const Conv2DDimensions& dims) {
+  return GetRandomTensor({dims.input_batches, dims.input_height,
+                          dims.input_width, dims.filter_count});
+}
+
+// Get a Tensor encoding Conv2D input shape.
+static Tensor GetInputSizesTensor(const Conv2DDimensions& dims) {
+  return test::AsTensor<int32>({dims.input_batches, dims.input_height,
+                                dims.input_width, dims.input_depth});
+}
+
+// Get a Tensor encoding Conv2D filter shape.
+static Tensor GetFilterSizesTensor(const Conv2DDimensions& dims) {
+  return test::AsTensor<int32>({dims.filter_height, dims.filter_width,
+                                dims.input_depth, dims.filter_count});
+}
+
+#if defined(INTEL_MKL_DNN_ONLY)
+static Tensor NonMklTensor() {
+  MklDnnShape non_mkl_shape;
+  non_mkl_shape.SetMklTensor(false);
+
+  auto size = static_cast<int64>(non_mkl_shape.GetSerializeBufferSize());
+  Tensor tensor(DT_UINT8, {size});
+
+  non_mkl_shape.SerializeMklDnnShape(tensor.flat<uint8>().data(),
+                                     size * sizeof(uint8));
+  return tensor;
+}
+#endif
+
+static Graph* DefaultConv2D(const Conv2DDimensions& dims) {
+  auto* graph = new Graph(OpRegistry::Global());
+
+  Tensor input_t = GetRandomInputTensor(dims);
+  Tensor filter_t = GetRandomFilterTensor(dims);
+
+  Node* input = test::graph::Constant(graph, input_t, "input");
+  Node* filter = test::graph::Constant(graph, filter_t, "filter");
+
+  Node* conv2d;
+  TF_CHECK_OK(NodeBuilder(graph->NewName("conv_2d"), "Conv2D")
+                  .Input(input)
+                  .Input(filter)
+                  .Attr("T", DT_FLOAT)
+                  .Attr("strides", {1, 1, 1, 1})
+                  .Attr("padding", "SAME")
+                  .Finalize(graph, &conv2d));
+
+  return graph;
+}
+
+#if defined(INTEL_MKL_DNN_ONLY)
+static Graph* MklConv2D(const Conv2DDimensions& dims) {
+  auto* graph = new Graph(OpRegistry::Global());
+
+  Tensor input_t = GetRandomInputTensor(dims);
+  Tensor filter_t = GetRandomFilterTensor(dims);
+
+  Node* input = test::graph::Constant(graph, input_t, "input");
+  Node* filter = test::graph::Constant(graph, filter_t, "filter");
+
+  Node* not_mkl_shape = test::graph::Constant(graph, NonMklTensor(), "not_mkl");
+
+  Node* conv2d;
+  TF_CHECK_OK(NodeBuilder(graph->NewName("mkl_conv_2d"), "_MklConv2D")
+                  .Input(input)
+                  .Input(filter)
+                  .Input(not_mkl_shape)
+                  .Input(not_mkl_shape)
+                  .Attr("T", DT_FLOAT)
+                  .Attr("strides", {1, 1, 1, 1})
+                  .Attr("padding", "SAME")
+                  .Attr("_kernel", "MklOp")
+                  .Finalize(graph, &conv2d));
+
+  return graph;
+}
+#endif
+
+static Graph* DefaultConv2DBwdInput(const Conv2DDimensions& dims) {
+  auto* graph = new Graph(OpRegistry::Global());
+
+  Tensor input_sizes_t = GetInputSizesTensor(dims);
+  Tensor filter_t = GetRandomFilterTensor(dims);
+  Tensor out_backprop_t = GetRandomOutputTensor(dims);  // assuming SAME padding
+
+  Node* input_sizes =
+      test::graph::Constant(graph, input_sizes_t, "input_sizes");
+  Node* filter = test::graph::Constant(graph, filter_t, "filter");
+  Node* out_backprop =
+      test::graph::Constant(graph, out_backprop_t, "out_backprop");
+
+  Node* conv2d_bwd_input;
+  TF_CHECK_OK(
+      NodeBuilder(graph->NewName("conv_2d_bwd_input"), "Conv2DBackpropInput")
+          .Input(input_sizes)
+          .Input(filter)
+          .Input(out_backprop)
+          .Attr("T", DT_FLOAT)
+          .Attr("strides", {1, 1, 1, 1})
+          .Attr("padding", "SAME")
+          .Finalize(graph, &conv2d_bwd_input));
+
+  return graph;
+}
+
+#if defined(INTEL_MKL_DNN_ONLY)
+static Graph* MklConv2DBwdInput(const Conv2DDimensions& dims) {
+  auto* graph = new Graph(OpRegistry::Global());
+
+  Tensor input_sizes_t = GetInputSizesTensor(dims);
+  Tensor filter_t = GetRandomFilterTensor(dims);
+  Tensor out_backprop_t = GetRandomOutputTensor(dims);  // assuming SAME padding
+
+  Node* input_sizes =
+      test::graph::Constant(graph, input_sizes_t, "input_sizes");
+  Node* filter = test::graph::Constant(graph, filter_t, "filter");
+  Node* out_backprop =
+      test::graph::Constant(graph, out_backprop_t, "out_backprop");
+
+  Node* not_mkl_shape = test::graph::Constant(graph, NonMklTensor(), "not_mkl");
+
+  Node* conv2d_bwd_input;
+  TF_CHECK_OK(NodeBuilder(graph->NewName("conv_2d_bwd_input"),
+                          "_MklConv2DBackpropInput")
+                  .Input(input_sizes)
+                  .Input(filter)
+                  .Input(out_backprop)
+                  .Input(not_mkl_shape)
+                  .Input(not_mkl_shape)
+                  .Input(not_mkl_shape)
+                  .Attr("T", DT_FLOAT)
+                  .Attr("strides", {1, 1, 1, 1})
+                  .Attr("padding", "SAME")
+                  .Attr("_kernel", "MklOp")
+                  .Finalize(graph, &conv2d_bwd_input));
+
+  return graph;
+}
+#endif
+
+static Graph* DefaultConv2DBwdFilter(const Conv2DDimensions& dims) {
+  auto* graph = new Graph(OpRegistry::Global());
+
+  Tensor input_t = GetRandomInputTensor(dims);
+  Tensor filter_sizes_t = GetFilterSizesTensor(dims);
+  Tensor filter_t = GetRandomFilterTensor(dims);
+  Tensor out_backprop_t = GetRandomOutputTensor(dims);  // assuming SAME padding
+
+  Node* input = test::graph::Constant(graph, input_t, "input");
+  Node* filter_sizes =
+      test::graph::Constant(graph, filter_sizes_t, "filter_sizes");
+  Node* out_backprop =
+      test::graph::Constant(graph, out_backprop_t, "out_backprop");
+
+  Node* conv2d_bwd_filter;
+  TF_CHECK_OK(
+      NodeBuilder(graph->NewName("conv_2d_bwd_filter"), "Conv2DBackpropFilter")
+          .Input(input)
+          .Input(filter_sizes)
+          .Input(out_backprop)
+          .Attr("T", DT_FLOAT)
+          .Attr("strides", {1, 1, 1, 1})
+          .Attr("padding", "SAME")
+          .Finalize(graph, &conv2d_bwd_filter));
+
+  return graph;
+}
+
+#if defined(INTEL_MKL_DNN_ONLY)
+static Graph* MklConv2DBwdFilter(const Conv2DDimensions& dims) {
+  Graph* graph = new Graph(OpRegistry::Global());
+
+  Tensor input_t = GetRandomInputTensor(dims);
+  Tensor filter_sizes_t = GetFilterSizesTensor(dims);
+  Tensor filter_t = GetRandomFilterTensor(dims);
+  Tensor out_backprop_t = GetRandomOutputTensor(dims);  // assuming SAME padding
+
+  Node* input = test::graph::Constant(graph, input_t, "input");
+  Node* filter_sizes =
+      test::graph::Constant(graph, filter_sizes_t, "filter_sizes");
+  Node* out_backprop =
+      test::graph::Constant(graph, out_backprop_t, "out_backprop");
+
+  Node* not_mkl_shape = test::graph::Constant(graph, NonMklTensor(), "not_mkl");
+
+  Node* conv2d_bwd_filter;
+  TF_CHECK_OK(NodeBuilder(graph->NewName("conv_2d_bwd_filter"),
+                          "_MklConv2DBackpropFilter")
+                  .Input(input)
+                  .Input(filter_sizes)
+                  .Input(out_backprop)
+                  .Input(not_mkl_shape)
+                  .Input(not_mkl_shape)
+                  .Input(not_mkl_shape)
+                  .Attr("T", DT_FLOAT)
+                  .Attr("strides", {1, 1, 1, 1})
+                  .Attr("padding", "SAME")
+                  .Attr("_kernel", "MklOp")
+                  .Finalize(graph, &conv2d_bwd_filter));
+
+  return graph;
+}
+#endif
+
+// Macro argument names: ---------------------------------------------------- //
+//    N: batch size
+//    H: height
+//    W: width
+//    C: channels
+//   FC: filter count
+//   FH: filter height
+//   FW: filter width
+
+#define BM_CONCAT(a, b) a##b
+
+#define BM_NAME(p, type, N, H, W, C, FC, FH, FW) \
+  BM_CONCAT(BM_##p##_##type##_in_##N##_##H##_##W##_##C, _f_##FC##_##FH##_##FW)
+
+// Flops computation in these benchmarks are the same as in
+// eigen_benchmark_cpu_test.cc.
+
+#define BM_Conv2DT(kind, N, H, W, C, FC, FH, FW, type, LABEL)            \
+  static void BM_NAME(Conv2D_##kind, type, N, H, W, C, FC, FH,           \
+                      FW)(int iters) {                                   \
+    testing::SetLabel(LABEL);                                            \
+                                                                         \
+    int64 num_computed_elements = (N) * (H) * (W) * (FC);                \
+    int64 flops_per_iter = num_computed_elements * ((C) * (FH) * (FW));  \
+    testing::ItemsProcessed(static_cast<int64>(iters) * flops_per_iter); \
+                                                                         \
+    Conv2DDimensions dims(N, H, W, C, FC, FH, FW);                       \
+    test::Benchmark(#type, BM_CONCAT(kind, Conv2D)(dims)).Run(iters);    \
+  }                                                                      \
+  BENCHMARK(BM_NAME(Conv2D_##kind, type, N, H, W, C, FC, FH, FW))
+
+#if defined(INTEL_MKL_DNN_ONLY)
+#define BM_Conv2D(N, H, W, C, FC, FH, FW, type, LABEL)      \
+  BM_Conv2DT(Default, N, H, W, C, FC, FH, FW, type, LABEL); \
+  BM_Conv2DT(Mkl, N, H, W, C, FC, FH, FW, type, LABEL);
+#else
+#define BM_Conv2D(N, H, W, C, FC, FH, FW, type, LABEL) \
+  BM_Conv2DT(Default, N, H, W, C, FC, FH, FW, type, LABEL);
+#endif
+
+#define BM_Conv2DBwdInputT(kind, N, H, W, C, FC, FH, FW, type, LABEL)         \
+  static void BM_NAME(Conv2DBwdInput_##kind, type, N, H, W, C, FC, FH,        \
+                      FW)(int iters) {                                        \
+    testing::SetLabel(LABEL);                                                 \
+                                                                              \
+    int64 num_computed_elements = (N) * (H) * (W) * (C);                      \
+    int64 flops_per_iter = num_computed_elements * ((C) * (FH) * (FW));       \
+    testing::ItemsProcessed(static_cast<int64>(iters) * flops_per_iter);      \
+                                                                              \
+    Conv2DDimensions dims(N, H, W, C, FC, FH, FW);                            \
+    test::Benchmark(#type, BM_CONCAT(kind, Conv2DBwdInput)(dims)).Run(iters); \
+  }                                                                           \
+  BENCHMARK(BM_NAME(Conv2DBwdInput_##kind, type, N, H, W, C, FC, FH, FW))
+
+#if defined(INTEL_MKL_DNN_ONLY)
+#define BM_Conv2DBwdInput(N, H, W, C, FC, FH, FW, type, LABEL)      \
+  BM_Conv2DBwdInputT(Default, N, H, W, C, FC, FH, FW, type, LABEL); \
+  BM_Conv2DBwdInputT(Mkl, N, H, W, C, FC, FH, FW, type, LABEL);
+#else
+#define BM_Conv2DBwdInput(N, H, W, C, FC, FH, FW, type, LABEL) \
+  BM_Conv2DBwdInputT(Default, N, H, W, C, FC, FH, FW, type, LABEL);
+#endif
+
+#define BM_Conv2DBwdFilterT(kind, N, H, W, C, FC, FH, FW, type, LABEL)         \
+  static void BM_NAME(Conv2DBwdFilter_##kind, type, N, H, W, C, FC, FH,        \
+                      FW)(int iters) {                                         \
+    testing::SetLabel(LABEL);                                                  \
+                                                                               \
+    int64 num_computed_elements = (FH) * (FW) * (C) * (FC);                    \
+    int64 flops_per_iter = num_computed_elements * ((N) * (H) * (W));          \
+    testing::ItemsProcessed(static_cast<int64>(iters) * flops_per_iter);       \
+                                                                               \
+    Conv2DDimensions dims(N, H, W, C, FC, FH, FW);                             \
+    test::Benchmark(#type, BM_CONCAT(kind, Conv2DBwdFilter)(dims)).Run(iters); \
+  }                                                                            \
+  BENCHMARK(BM_NAME(Conv2DBwdFilter_##kind, type, N, H, W, C, FC, FH, FW))
+
+#if defined(INTEL_MKL_DNN_ONLY)
+#define BM_Conv2DBwdFilter(N, H, W, C, FC, FH, FW, type, LABEL)      \
+  BM_Conv2DBwdFilterT(Default, N, H, W, C, FC, FH, FW, type, LABEL); \
+  BM_Conv2DBwdFilterT(Mkl, N, H, W, C, FC, FH, FW, type, LABEL);
+#else
+#define BM_Conv2DBwdFilter(N, H, W, C, FC, FH, FW, type, LABEL) \
+  BM_Conv2DBwdFilterT(Default, N, H, W, C, FC, FH, FW, type, LABEL);
+#endif
+
+// ImageNet Convolutions ---------------------------------------------------- //
+
+BM_Conv2D(32, 28, 28, 96, 128, 3, 3, cpu, "conv3a_00_3x3");
+BM_Conv2D(32, 28, 28, 16, 32, 5, 5, cpu, "conv3a_00_5x5");
+BM_Conv2D(32, 28, 28, 128, 192, 3, 3, cpu, "conv3_00_3x3");
+BM_Conv2D(32, 28, 28, 32, 96, 5, 5, cpu, "conv3_00_5x5");
+BM_Conv2D(32, 14, 14, 96, 204, 3, 3, cpu, "conv4a_00_3x3");
+BM_Conv2D(32, 14, 14, 16, 48, 5, 5, cpu, "conv4a_00_5x5");
+BM_Conv2D(32, 14, 14, 112, 224, 3, 3, cpu, "conv4b_00_3x3");
+
+BM_Conv2DBwdInput(32, 28, 28, 96, 128, 3, 3, cpu, "conv3a_00_3x3");
+BM_Conv2DBwdInput(32, 28, 28, 16, 32, 5, 5, cpu, "conv3a_00_5x5");
+BM_Conv2DBwdInput(32, 28, 28, 128, 192, 3, 3, cpu, "conv3_00_3x3");
+BM_Conv2DBwdInput(32, 28, 28, 32, 96, 5, 5, cpu, "conv3_00_5x5");
+BM_Conv2DBwdInput(32, 14, 14, 96, 204, 3, 3, cpu, "conv4a_00_3x3");
+BM_Conv2DBwdInput(32, 14, 14, 16, 48, 5, 5, cpu, "conv4a_00_5x5");
+BM_Conv2DBwdInput(32, 14, 14, 112, 224, 3, 3, cpu, "conv4b_00_3x3");
+
+BM_Conv2DBwdFilter(32, 28, 28, 96, 128, 3, 3, cpu, "conv3a_00_3x3");
+BM_Conv2DBwdFilter(32, 28, 28, 16, 32, 5, 5, cpu, "conv3a_00_5x5");
+BM_Conv2DBwdFilter(32, 28, 28, 128, 192, 3, 3, cpu, "conv3_00_3x3");
+BM_Conv2DBwdFilter(32, 28, 28, 32, 96, 5, 5, cpu, "conv3_00_5x5");
+BM_Conv2DBwdFilter(32, 14, 14, 96, 204, 3, 3, cpu, "conv4a_00_3x3");
+BM_Conv2DBwdFilter(32, 14, 14, 16, 48, 5, 5, cpu, "conv4a_00_5x5");
+BM_Conv2DBwdFilter(32, 14, 14, 112, 224, 3, 3, cpu, "conv4b_00_3x3");
+
+}  // namespace tensorflow
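
The ItemsProcessed figure in the macros above is the number of output elements multiplied by the per-element multiply-accumulate depth. As a hand check for the first forward benchmark, BM_Conv2D(32, 28, 28, 96, 128, 3, 3), a standalone computation of the same quantities:

    #include <cstdio>

    int main() {
      // Same accounting as BM_Conv2DT: output elements (N*H*W*FC) times the
      // multiply-accumulate depth per element (C*FH*FW).
      const long long N = 32, H = 28, W = 28, C = 96, FC = 128, FH = 3, FW = 3;
      const long long num_computed_elements = N * H * W * FC;  // 3,211,264
      const long long flops_per_iter = num_computed_elements * (C * FH * FW);
      std::printf("elements=%lld flops/iter=%lld (~%.2f GFLOP)\n",
                  num_computed_elements, flops_per_iter, flops_per_iter / 1e9);
      return 0;
    }
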
diff --git a/tensorflow/core/kernels/multinomial_op.cc b/tensorflow/core/kernels/multinomial_op.cc
index 7a64788..82dfece 100644
--- a/tensorflow/core/kernels/multinomial_op.cc
+++ b/tensorflow/core/kernels/multinomial_op.cc
@@ -75,7 +75,7 @@
       // lambda.  Since we want to let each worker have its own copy, we pass
       // "gen" by reference and explicitly do a copy assignment here.
       random::PhiloxRandom gen_copy = gen;
-      // Skip takes units of 128 bytes.  +3 is so rounding doesn't lead to
+      // Skip takes units of 128 bits.  +3 is so rounding doesn't lead to
       // us using the same state in different batches.
       gen_copy.Skip(start_row * (num_samples + 3) / 4);
       random::SimplePhilox simple_philox(&gen_copy);
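
The corrected comment above is about units: each Skip(1) advances PhiloxRandom by one 128-bit block, i.e. four 32-bit outputs, so (num_samples + 3) / 4 is the ceiling of num_samples / 4 and every row starts on a fresh block boundary. A small check of that ceiling division:

    #include <cstdio>

    int main() {
      // (num_samples + 3) / 4 == ceil(num_samples / 4.0): rows never share a
      // 128-bit Philox block even when num_samples is not a multiple of four.
      for (int num_samples = 1; num_samples <= 9; ++num_samples) {
        const int blocks_per_row = (num_samples + 3) / 4;
        std::printf("num_samples=%d -> blocks skipped per row=%d\n",
                    num_samples, blocks_per_row);
      }
      return 0;
    }
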
diff --git a/tensorflow/core/kernels/queue_base.h b/tensorflow/core/kernels/queue_base.h
index 5fb1c92..272aa3b 100644
--- a/tensorflow/core/kernels/queue_base.h
+++ b/tensorflow/core/kernels/queue_base.h
@@ -19,6 +19,7 @@
 #include <deque>
 #include <vector>
 
+#include "absl/base/macros.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/queue_interface.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -82,6 +83,9 @@
   // NOTE(mrry): This method is deprecated. Use
   // `tensorflow::batch_util::CopySliceToElement()` defined in
   // "./batch_util.h" instead.
+  ABSL_DEPRECATED(
+      "Use `tensorflow::batch_util::CopySliceToElement()` defined in "
+      "\"./batch_util.h\" instead.")
   static Status CopyElementToSlice(const Tensor& element, Tensor* parent,
                                    int64 index);
 
diff --git a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
index 88b3c2a..bb8254e 100644
--- a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
+++ b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
@@ -21,11 +21,11 @@
 #define EIGEN_USE_GPU
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "external/cub_archive/cub/device/device_reduce.cuh"
-#include "external/cub_archive/cub/device/device_segmented_reduce.cuh"
-#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh"
-#include "external/cub_archive/cub/iterator/transform_input_iterator.cuh"
-#include "external/cub_archive/cub/warp/warp_reduce.cuh"
+#include "third_party/cub/device/device_reduce.cuh"
+#include "third_party/cub/device/device_segmented_reduce.cuh"
+#include "third_party/cub/iterator/counting_input_iterator.cuh"
+#include "third_party/cub/iterator/transform_input_iterator.cuh"
+#include "third_party/cub/warp/warp_reduce.cuh"
 #include "cuda/include/cuComplex.h"
 #include "tensorflow/core/kernels/reduction_ops.h"
 #include "tensorflow/core/lib/core/bits.h"
diff --git a/tensorflow/core/kernels/reduction_ops_max.cc b/tensorflow/core/kernels/reduction_ops_max.cc
index 9cf953f..8bfa44b 100644
--- a/tensorflow/core/kernels/reduction_ops_max.cc
+++ b/tensorflow/core/kernels/reduction_ops_max.cc
@@ -50,6 +50,8 @@
           .TypeConstraint<int64>("Tidx")                                       \
           .HostMemory("reduction_indices"),                                    \
       ReductionOp<GPUDevice, type, int64, Eigen::internal::MaxReducer<type>>);
+
+REGISTER_GPU_KERNELS(Eigen::half);
 REGISTER_GPU_KERNELS(float);
 REGISTER_GPU_KERNELS(double);
 REGISTER_GPU_KERNELS(int64);
diff --git a/tensorflow/core/kernels/reduction_ops_sum.cc b/tensorflow/core/kernels/reduction_ops_sum.cc
index e4ca89e..5318d8c 100644
--- a/tensorflow/core/kernels/reduction_ops_sum.cc
+++ b/tensorflow/core/kernels/reduction_ops_sum.cc
@@ -76,15 +76,7 @@
         .HostMemory("output")
         .HostMemory("reduction_indices"),
     ReductionOp<CPUDevice, int32, int64, Eigen::internal::SumReducer<int32>>);
-REGISTER_KERNEL_BUILDER(
-    Name("Sum")
-        .Device(DEVICE_GPU)
-        .TypeConstraint<int64>("T")
-        .TypeConstraint<int32>("Tidx")
-        .HostMemory("input")
-        .HostMemory("output")
-        .HostMemory("reduction_indices"),
-    ReductionOp<CPUDevice, int64, int32, Eigen::internal::SumReducer<int64>>);
+
 #endif
 
 #ifdef TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/kernels/searchsorted_op.cc b/tensorflow/core/kernels/searchsorted_op.cc
new file mode 100644
index 0000000..dc627ac
--- /dev/null
+++ b/tensorflow/core/kernels/searchsorted_op.cc
@@ -0,0 +1,249 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/searchsorted_op.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+template <typename T, typename OutType>
+struct UpperBoundFunctor<CPUDevice, T, OutType> {
+  static Status Compute(OpKernelContext* context,
+                        const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+                        const typename TTypes<T, 1>::ConstTensor& values,
+                        int batch_size, int num_inputs, int num_values,
+                        typename TTypes<OutType, 1>::Tensor* output) {
+    // TODO(eriche): If anyone ever needs this to be faster, we can multithread.
+    for (int b = 0; b < batch_size; ++b) {
+      const T* sorted_inputs_ptr = sorted_inputs.data() + b * num_inputs;
+      OutType* output_ptr = output->data() + b * num_values;
+      for (int i = 0; i < num_values; ++i) {
+        output_ptr[i] =
+            std::upper_bound(sorted_inputs_ptr, sorted_inputs_ptr + num_inputs,
+                             values(i + b * num_values)) -
+            sorted_inputs_ptr;
+      }
+    }
+
+    return Status::OK();
+  }
+};
+
+template <typename T, typename OutType>
+struct LowerBoundFunctor<CPUDevice, T, OutType> {
+  static Status Compute(OpKernelContext* context,
+                        const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+                        const typename TTypes<T, 1>::ConstTensor& values,
+                        int batch_size, int num_inputs, int num_values,
+                        typename TTypes<OutType, 1>::Tensor* output) {
+    // TODO(eriche): If anyone ever needs this to be faster, we can multithread.
+    for (int b = 0; b < batch_size; ++b) {
+      const T* sorted_inputs_ptr = sorted_inputs.data() + b * num_inputs;
+      OutType* output_ptr = output->data() + b * num_values;
+      for (int i = 0; i < num_values; ++i) {
+        output_ptr[i] =
+            std::lower_bound(sorted_inputs_ptr, sorted_inputs_ptr + num_inputs,
+                             values(i + b * num_values)) -
+            sorted_inputs_ptr;
+      }
+    }
+
+    return Status::OK();
+  }
+};
+}  // namespace functor
+
+template <typename Device, typename T, typename OutType>
+class UpperBoundOp : public OpKernel {
+ public:
+  explicit UpperBoundOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor& sorted_inputs_t = ctx->input(0);
+    const Tensor& values_t = ctx->input(1);
+
+    // must have same batch dim_size for both
+    OP_REQUIRES(ctx, sorted_inputs_t.dim_size(0) == values_t.dim_size(0),
+                Status(error::INVALID_ARGUMENT,
+                       "Leading dim_size of both tensors must match."));
+
+    // this is required because we do indexing in int32 on the GPU
+    OP_REQUIRES(ctx, values_t.NumElements() < std::numeric_limits<int>::max(),
+                Status(error::INVALID_ARGUMENT,
+                       "values tensor size must be less than INT_MAX"));
+
+    Tensor* output_t;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, values_t.shape(), &output_t));
+
+    if (output_t->dtype() == DT_INT32) {
+      OP_REQUIRES(ctx,
+                  FastBoundsCheck(sorted_inputs_t.dim_size(1),
+                                  std::numeric_limits<int>::max()),
+                  errors::InvalidArgument("trailing dim_size must be less than "
+                                          "INT_MAX for int32 output type, was ",
+                                          sorted_inputs_t.dim_size(1)));
+    }
+
+    auto output = output_t->template flat<OutType>();
+    const auto sorted_inputs = sorted_inputs_t.template flat<T>();
+    const auto values = values_t.template flat<T>();
+    OP_REQUIRES_OK(
+        ctx, functor::UpperBoundFunctor<Device, T, OutType>::Compute(
+                 ctx, sorted_inputs, values, sorted_inputs_t.dim_size(0),
+                 sorted_inputs_t.dim_size(1), values_t.dim_size(1), &output));
+  }
+};
+
+template <typename Device, typename T, typename OutType>
+class LowerBoundOp : public OpKernel {
+ public:
+  explicit LowerBoundOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor& sorted_inputs_t = ctx->input(0);
+    const Tensor& values_t = ctx->input(1);
+
+    // must have same batch dim_size for both
+    OP_REQUIRES(ctx, sorted_inputs_t.dim_size(0) == values_t.dim_size(0),
+                Status(error::INVALID_ARGUMENT,
+                       "Leading dim_size of both tensors must match."));
+
+    // this is required because we do indexing in int32 on the GPU
+    OP_REQUIRES(ctx, values_t.NumElements() < std::numeric_limits<int>::max(),
+                Status(error::INVALID_ARGUMENT,
+                       "values tensor size must be less than INT_MAX"));
+
+    Tensor* output_t;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, values_t.shape(), &output_t));
+
+    if (output_t->dtype() == DT_INT32) {
+      OP_REQUIRES(ctx,
+                  FastBoundsCheck(sorted_inputs_t.dim_size(1),
+                                  std::numeric_limits<int>::max()),
+                  errors::InvalidArgument("trailing dim_size must be less than "
+                                          "INT_MAX for int32 output type, was ",
+                                          sorted_inputs_t.dim_size(1)));
+    }
+
+    auto output = output_t->template flat<OutType>();
+    const auto sorted_inputs = sorted_inputs_t.template flat<T>();
+    const auto values = values_t.template flat<T>();
+    OP_REQUIRES_OK(
+        ctx, functor::LowerBoundFunctor<Device, T, OutType>::Compute(
+                 ctx, sorted_inputs, values, sorted_inputs_t.dim_size(0),
+                 sorted_inputs_t.dim_size(1), values_t.dim_size(1), &output));
+  }
+};
+
+#define REGISTER_KERNELS(type)                                    \
+  REGISTER_KERNEL_BUILDER(Name("UpperBound")                      \
+                              .Device(DEVICE_CPU)                 \
+                              .TypeConstraint<type>("T")          \
+                              .TypeConstraint<int32>("out_type"), \
+                          UpperBoundOp<CPUDevice, type, int32>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#define REGISTER_KERNELS(type)                                    \
+  REGISTER_KERNEL_BUILDER(Name("UpperBound")                      \
+                              .Device(DEVICE_CPU)                 \
+                              .TypeConstraint<type>("T")          \
+                              .TypeConstraint<int64>("out_type"), \
+                          UpperBoundOp<CPUDevice, type, int64>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#if GOOGLE_CUDA
+
+#define REGISTER_KERNELS(type)                                    \
+  REGISTER_KERNEL_BUILDER(Name("UpperBound")                      \
+                              .Device(DEVICE_GPU)                 \
+                              .TypeConstraint<type>("T")          \
+                              .TypeConstraint<int32>("out_type"), \
+                          UpperBoundOp<GPUDevice, type, int32>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#define REGISTER_KERNELS(type)                                    \
+  REGISTER_KERNEL_BUILDER(Name("UpperBound")                      \
+                              .Device(DEVICE_GPU)                 \
+                              .TypeConstraint<type>("T")          \
+                              .TypeConstraint<int64>("out_type"), \
+                          UpperBoundOp<GPUDevice, type, int64>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#endif  // GOOGLE_CUDA
+
+#define REGISTER_KERNELS(type)                                    \
+  REGISTER_KERNEL_BUILDER(Name("LowerBound")                      \
+                              .Device(DEVICE_CPU)                 \
+                              .TypeConstraint<type>("T")          \
+                              .TypeConstraint<int32>("out_type"), \
+                          LowerBoundOp<CPUDevice, type, int32>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#define REGISTER_KERNELS(type)                                    \
+  REGISTER_KERNEL_BUILDER(Name("LowerBound")                      \
+                              .Device(DEVICE_CPU)                 \
+                              .TypeConstraint<type>("T")          \
+                              .TypeConstraint<int64>("out_type"), \
+                          LowerBoundOp<CPUDevice, type, int64>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#if GOOGLE_CUDA
+
+#define REGISTER_KERNELS(type)                                    \
+  REGISTER_KERNEL_BUILDER(Name("LowerBound")                      \
+                              .Device(DEVICE_GPU)                 \
+                              .TypeConstraint<type>("T")          \
+                              .TypeConstraint<int32>("out_type"), \
+                          LowerBoundOp<GPUDevice, type, int32>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#define REGISTER_KERNELS(type)                                    \
+  REGISTER_KERNEL_BUILDER(Name("LowerBound")                      \
+                              .Device(DEVICE_GPU)                 \
+                              .TypeConstraint<type>("T")          \
+                              .TypeConstraint<int64>("out_type"), \
+                          LowerBoundOp<GPUDevice, type, int64>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#endif  // GOOGLE_CUDA
+}  // namespace tensorflow
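
On CPU the functors above are per-row calls to std::upper_bound and std::lower_bound, so the ops return insertion indices into each sorted row rather than element values. A self-contained example of those index semantics on a single row:

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main() {
      // For the sorted row {1, 3, 3, 5}: LowerBound of 3 is the first index
      // where 3 could be inserted keeping order, UpperBound is one past the last.
      const std::vector<float> sorted = {1.f, 3.f, 3.f, 5.f};
      const std::vector<float> values = {0.f, 3.f, 6.f};
      for (float v : values) {
        const auto lo =
            std::lower_bound(sorted.begin(), sorted.end(), v) - sorted.begin();
        const auto hi =
            std::upper_bound(sorted.begin(), sorted.end(), v) - sorted.begin();
        std::printf("value %.0f -> LowerBound %td, UpperBound %td\n", v, lo, hi);
      }
      return 0;
    }
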
diff --git a/tensorflow/core/kernels/searchsorted_op.h b/tensorflow/core/kernels/searchsorted_op.h
new file mode 100644
index 0000000..f075bf0
--- /dev/null
+++ b/tensorflow/core/kernels/searchsorted_op.h
@@ -0,0 +1,52 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_SEARCHSORTED_OP_H_
+#define TENSORFLOW_CORE_KERNELS_SEARCHSORTED_OP_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace functor {
+
+template <typename Device, typename T, typename OutType>
+struct UpperBoundFunctor {
+  // Searches for values in sorted_inputs and returns the greatest possible
+  // index where they maintain sorted order.
+  static Status Compute(OpKernelContext* context,
+                        const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+                        const typename TTypes<T, 1>::ConstTensor& values,
+                        int batch_size, int num_inputs, int num_values,
+                        typename TTypes<OutType, 1>::Tensor* output);
+};
+
+template <typename Device, typename T, typename OutType>
+struct LowerBoundFunctor {
+  // Searches sorted_inputs for each value and returns the lowest index at
+  // which the value could be inserted while keeping sorted_inputs sorted.
+  static Status Compute(OpKernelContext* context,
+                        const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+                        const typename TTypes<T, 1>::ConstTensor& values,
+                        int batch_size, int num_inputs, int num_values,
+                        typename TTypes<OutType, 1>::Tensor* output);
+};
+}  // namespace functor
+
+}  // end namespace tensorflow
+#endif  // TENSORFLOW_CORE_KERNELS_SEARCHSORTED_OP_H_
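
For orientation, here is a minimal host-side sketch (not TensorFlow code) of what these functors compute: each of the `batch_size` rows of `sorted_inputs` is searched independently, and every entry of the matching row of `values` yields the insertion index given by `std::upper_bound` (or `std::lower_bound`) on that row. The row-major layout and the function name below are illustrative assumptions.

```cpp
// Reference semantics only; a sketch, not the TensorFlow kernel.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// sorted_inputs: batch_size x num_inputs (row-major), each row sorted.
// values:        batch_size x num_values (row-major).
// Returns batch_size x num_values upper-bound insertion indices.
std::vector<int32_t> UpperBoundReference(const std::vector<float>& sorted_inputs,
                                         const std::vector<float>& values,
                                         int batch_size, int num_inputs,
                                         int num_values) {
  std::vector<int32_t> out(batch_size * num_values);
  for (int b = 0; b < batch_size; ++b) {
    const float* row = sorted_inputs.data() + b * num_inputs;
    for (int j = 0; j < num_values; ++j) {
      const float v = values[b * num_values + j];
      out[b * num_values + j] = static_cast<int32_t>(
          std::upper_bound(row, row + num_inputs, v) - row);
    }
  }
  return out;
}

int main() {
  // One batch row [1, 3, 3, 5]; values [3, 4] -> upper-bound indices [3, 3].
  std::vector<float> sorted_inputs = {1, 3, 3, 5};
  std::vector<float> values = {3, 4};
  for (int32_t idx : UpperBoundReference(sorted_inputs, values, 1, 4, 2))
    std::cout << idx << " ";
  std::cout << "\n";
}
```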
diff --git a/tensorflow/core/kernels/searchsorted_op_gpu.cu.cc b/tensorflow/core/kernels/searchsorted_op_gpu.cu.cc
new file mode 100644
index 0000000..263b5bf
--- /dev/null
+++ b/tensorflow/core/kernels/searchsorted_op_gpu.cu.cc
@@ -0,0 +1,126 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/searchsorted_op.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
+
+namespace tensorflow {
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace {
+template <typename T, typename OutType>
+__global__ void UpperBoundKernel(const T* sorted_inputs, int batch_size,
+                                 int sorted_inputs_size, int values_size,
+                                 const T* values, OutType* outputs) {
+  CUDA_1D_KERNEL_LOOP(work_unit_id, values_size * batch_size) {
+    int bid = work_unit_id / values_size;
+    T value = values[work_unit_id];
+    outputs[work_unit_id] = cuda_helper::upper_bound<T, OutType>(
+        sorted_inputs + bid * sorted_inputs_size, sorted_inputs_size, value);
+  }
+}
+
+template <typename T, typename OutType>
+__global__ void LowerBoundKernel(const T* sorted_inputs, int batch_size,
+                                 int sorted_inputs_size, int values_size,
+                                 const T* values, OutType* outputs) {
+  CUDA_1D_KERNEL_LOOP(work_unit_id, values_size * batch_size) {
+    int bid = work_unit_id / values_size;
+    T value = values[work_unit_id];
+    outputs[work_unit_id] = cuda_helper::lower_bound<T, OutType>(
+        sorted_inputs + bid * sorted_inputs_size, sorted_inputs_size, value);
+  }
+}
+}  // namespace
+
+namespace functor {
+template <typename T, typename OutType>
+struct UpperBoundFunctor<GPUDevice, T, OutType> {
+  static Status Compute(OpKernelContext* context,
+                        const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+                        const typename TTypes<T, 1>::ConstTensor& values,
+                        int batch_size, int num_inputs, int num_values,
+                        typename TTypes<OutType, 1>::Tensor* output) {
+    const cudaStream_t& stream = GetCudaStream(context);
+    CudaLaunchConfig config =
+        GetCudaLaunchConfig(values.size(), context->eigen_gpu_device());
+
+    UpperBoundKernel<T>
+        <<<config.block_count, config.thread_per_block, 0, stream>>>(
+            sorted_inputs.data(), batch_size, num_inputs, num_values,
+            values.data(), output->data());
+
+    return Status::OK();
+  }
+};
+
+template <typename T, typename OutType>
+struct LowerBoundFunctor<GPUDevice, T, OutType> {
+  static Status Compute(OpKernelContext* context,
+                        const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+                        const typename TTypes<T, 1>::ConstTensor& values,
+                        int batch_size, int num_inputs, int num_values,
+                        typename TTypes<OutType, 1>::Tensor* output) {
+    const cudaStream_t& stream = GetCudaStream(context);
+    CudaLaunchConfig config =
+        GetCudaLaunchConfig(values.size(), context->eigen_gpu_device());
+
+    LowerBoundKernel<T>
+        <<<config.block_count, config.thread_per_block, 0, stream>>>(
+            sorted_inputs.data(), batch_size, num_inputs, num_values,
+            values.data(), output->data());
+
+    return Status::OK();
+  }
+};
+}  // namespace functor
+
+#define REGISTER_GPU_SPEC(type) \
+  template struct functor::UpperBoundFunctor<GPUDevice, type, int32>;
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_SPEC);
+#undef REGISTER_GPU_SPEC
+
+#define REGISTER_GPU_SPEC(type) \
+  template struct functor::UpperBoundFunctor<GPUDevice, type, int64>;
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_SPEC);
+#undef REGISTER_GPU_SPEC
+
+#define REGISTER_GPU_SPEC(type) \
+  template struct functor::LowerBoundFunctor<GPUDevice, type, int32>;
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_SPEC);
+#undef REGISTER_GPU_SPEC
+
+#define REGISTER_GPU_SPEC(type) \
+  template struct functor::LowerBoundFunctor<GPUDevice, type, int64>;
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_SPEC);
+#undef REGISTER_GPU_SPEC
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/string_format_op.cc b/tensorflow/core/kernels/string_format_op.cc
new file mode 100644
index 0000000..e4a1887
--- /dev/null
+++ b/tensorflow/core/kernels/string_format_op.cc
@@ -0,0 +1,65 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <iostream>
+#include "absl/strings/str_split.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+class StringFormatOp : public OpKernel {
+ public:
+  explicit StringFormatOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    string template_;
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("template", &template_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("placeholder", &placeholder_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("summarize", &summarize_));
+
+    split_template_ = absl::StrSplit(template_, placeholder_);
+    int64 num_placeholders = split_template_.size() - 1;
+    OP_REQUIRES(ctx, ctx->num_inputs() == num_placeholders,
+                errors::InvalidArgument(strings::StrCat(
+                    "num placeholders in template and num inputs must match: ",
+                    num_placeholders, " vs. ", ctx->num_inputs())));
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    Tensor* formatted_string = nullptr;
+    OP_REQUIRES_OK(ctx,
+                   ctx->allocate_output(0, TensorShape({}), &formatted_string));
+
+    string msg;
+    strings::StrAppend(&msg, split_template_[0].c_str());
+    for (int i = 0; i < ctx->num_inputs(); ++i) {
+      strings::StrAppend(&msg, ctx->input(i).SummarizeValue(summarize_, true));
+      strings::StrAppend(&msg, split_template_[i + 1].c_str());
+    }
+
+    formatted_string->scalar<string>()() = msg;
+  }
+
+ private:
+  int32 summarize_ = 0;
+  string placeholder_;
+  std::vector<std::string> split_template_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("StringFormat").Device(DEVICE_CPU),
+                        StringFormatOp);
+
+}  // end namespace tensorflow
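
The constructor above splits the template on the placeholder and requires the number of inputs to equal the number of placeholders (splits minus one); `Compute` then interleaves the summarized tensor values between the split pieces. A minimal sketch of that interleaving follows, using a hand-rolled split in place of `absl::StrSplit`, so it is illustrative only.

```cpp
// Illustrative sketch of the template/placeholder interleaving; not the op itself.
#include <iostream>
#include <string>
#include <vector>

// Splits `tmpl` on every occurrence of `placeholder` (stand-in for absl::StrSplit).
std::vector<std::string> SplitTemplate(const std::string& tmpl,
                                       const std::string& placeholder) {
  std::vector<std::string> parts;
  size_t pos = 0, next;
  while ((next = tmpl.find(placeholder, pos)) != std::string::npos) {
    parts.push_back(tmpl.substr(pos, next - pos));
    pos = next + placeholder.size();
  }
  parts.push_back(tmpl.substr(pos));
  return parts;
}

int main() {
  const std::string tmpl = "First tensor: %s, second tensor: %s";
  const auto parts = SplitTemplate(tmpl, "%s");
  // num_placeholders = parts.size() - 1 must equal the number of inputs.
  const std::vector<std::string> summarized = {"[1 2 3]", "[4 5 6]"};
  if (summarized.size() != parts.size() - 1) return 1;

  std::string msg = parts[0];
  for (size_t i = 0; i < summarized.size(); ++i) {
    msg += summarized[i];  // SummarizeValue(summarize_) plays this role in the op.
    msg += parts[i + 1];
  }
  std::cout << msg << "\n";  // "First tensor: [1 2 3], second tensor: [4 5 6]"
}
```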
diff --git a/tensorflow/core/kernels/string_format_op_test.cc b/tensorflow/core/kernels/string_format_op_test.cc
new file mode 100644
index 0000000..13130a5
--- /dev/null
+++ b/tensorflow/core/kernels/string_format_op_test.cc
@@ -0,0 +1,66 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+namespace {
+
+class StringFormatGraphTest : public OpsTestBase {
+ protected:
+  Status Init(int num_inputs, DataType input_type,
+              const string& template_ = "%s", const string& placeholder = "%s",
+              int summarize = 3) {
+    TF_CHECK_OK(NodeDefBuilder("op", "StringFormat")
+                    .Input(FakeInput(num_inputs, input_type))
+                    .Attr("template", template_)
+                    .Attr("placeholder", placeholder)
+                    .Attr("summarize", summarize)
+                    .Finalize(node_def()));
+    return InitOp();
+  }
+};
+
+TEST_F(StringFormatGraphTest, Int32Success_7) {
+  TF_ASSERT_OK(Init(1, DT_INT32, "First tensor: %s"));
+
+  AddInputFromArray<int32>(TensorShape({7}), {1, 2, 3, 4, 5, 6, 7});
+  TF_ASSERT_OK(RunOpKernel());
+  Tensor expected(allocator(), DT_STRING, TensorShape({}));
+  test::FillValues<string>(&expected, {"First tensor: [1 2 3 ... 5 6 7]"});
+  test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(StringFormatGraphTest, Int32Success_3_3) {
+  TF_ASSERT_OK(Init(1, DT_INT32, "First tensor: %s", "%s", 1));
+
+  AddInputFromArray<int32>(TensorShape({3, 3}), {1, 2, 3, 4, 5, 6, 7, 8, 9});
+  TF_ASSERT_OK(RunOpKernel());
+  Tensor expected(allocator(), DT_STRING, TensorShape({}));
+  test::FillValues<string>(&expected, {"First tensor: [[1 ... 3]\n ..."
+                                       "\n [7 ... 9]]"});
+  test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+}  // end namespace
+}  // end namespace tensorflow
diff --git a/tensorflow/core/kernels/topk_op_gpu.cu.cc b/tensorflow/core/kernels/topk_op_gpu.cu.cc
index ca296d5..2fbe1fe 100644
--- a/tensorflow/core/kernels/topk_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/topk_op_gpu.cu.cc
@@ -20,9 +20,9 @@
 #include <cmath>
 #include <vector>
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "external/cub_archive/cub/device/device_segmented_radix_sort.cuh"
-#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh"
-#include "external/cub_archive/cub/iterator/transform_input_iterator.cuh"
+#include "third_party/cub/device/device_segmented_radix_sort.cuh"
+#include "third_party/cub/iterator/counting_input_iterator.cuh"
+#include "third_party/cub/iterator/transform_input_iterator.cuh"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/core/kernels/unravel_index_op.cc b/tensorflow/core/kernels/unravel_index_op.cc
index 62e814f..8d839ba 100644
--- a/tensorflow/core/kernels/unravel_index_op.cc
+++ b/tensorflow/core/kernels/unravel_index_op.cc
@@ -97,10 +97,12 @@
 
       auto output = output_tensor->matrix<Tidx>();
 
-      Eigen::array<int64, 2> reshape{{dims_tensor.NumElements(), 1}};
-      Eigen::array<int64, 2> bcast({1, indices_tensor.NumElements()});
-      Eigen::array<int64, 2> indices_reshape{{1, indices_tensor.NumElements()}};
-      Eigen::array<int64, 2> indices_bcast({dims_tensor.NumElements(), 1});
+      Eigen::array<Eigen::Index, 2> reshape{{dims_tensor.NumElements(), 1}};
+      Eigen::array<Eigen::Index, 2> bcast({1, indices_tensor.NumElements()});
+      Eigen::array<Eigen::Index, 2> indices_reshape{
+          {1, indices_tensor.NumElements()}};
+      Eigen::array<Eigen::Index, 2> indices_bcast(
+          {dims_tensor.NumElements(), 1});
 
       output = indices_tensor.vec<Tidx>()
                    .reshape(indices_reshape)
diff --git a/tensorflow/core/kernels/where_op_gpu.cu.h b/tensorflow/core/kernels/where_op_gpu.cu.h
index 8879d9d..2255597 100644
--- a/tensorflow/core/kernels/where_op_gpu.cu.h
+++ b/tensorflow/core/kernels/where_op_gpu.cu.h
@@ -21,10 +21,10 @@
 #define EIGEN_USE_GPU
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "external/cub_archive/cub/device/device_reduce.cuh"
-#include "external/cub_archive/cub/device/device_select.cuh"
-#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh"
-#include "external/cub_archive/cub/iterator/transform_input_iterator.cuh"
+#include "third_party/cub/device/device_reduce.cuh"
+#include "third_party/cub/device/device_select.cuh"
+#include "third_party/cub/iterator/counting_input_iterator.cuh"
+#include "third_party/cub/iterator/transform_input_iterator.cuh"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/kernels/bounds_check.h"
diff --git a/tensorflow/core/lib/core/status.h b/tensorflow/core/lib/core/status.h
index 49f74ff..eb0ff55 100644
--- a/tensorflow/core/lib/core/status.h
+++ b/tensorflow/core/lib/core/status.h
@@ -24,6 +24,7 @@
 #include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/core/lib/core/stringpiece.h b/tensorflow/core/lib/core/stringpiece.h
index e7b17c9..6edff13 100644
--- a/tensorflow/core/lib/core/stringpiece.h
+++ b/tensorflow/core/lib/core/stringpiece.h
@@ -26,13 +26,7 @@
 #ifndef TENSORFLOW_CORE_LIB_CORE_STRINGPIECE_H_
 #define TENSORFLOW_CORE_LIB_CORE_STRINGPIECE_H_
 
-#include <assert.h>
-#include <stddef.h>
-#include <string.h>
-#include <iosfwd>
-#include <string>
 #include "absl/strings/string_view.h"
-#include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/core/lib/core/threadpool.cc b/tensorflow/core/lib/core/threadpool.cc
index 99684ae..9ccd911 100644
--- a/tensorflow/core/lib/core/threadpool.cc
+++ b/tensorflow/core/lib/core/threadpool.cc
@@ -17,6 +17,7 @@
 
 #define EIGEN_USE_THREADS
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/lib/core/blocking_counter.h"
 #include "tensorflow/core/platform/context.h"
 #include "tensorflow/core/platform/denormal.h"
 #include "tensorflow/core/platform/logging.h"
@@ -120,6 +121,54 @@
   impl_->Schedule(std::move(fn));
 }
 
+int ThreadPool::NumShardsUsedByTransformRangeConcurrently(
+    const int64 block_size, const int64 total) {
+  if (block_size <= 0 || total <= 1 || total <= block_size ||
+      NumThreads() == 1) {
+    return 1;
+  }
+  return (total + block_size - 1) / block_size;
+}
+
+// This functionality is similar to parallelFor, except that reasoning about
+// the number of shards used is significantly easier.
+void ThreadPool::TransformRangeConcurrently(
+    const int64 block_size, const int64 total,
+    const std::function<void(int64, int64)>& fn) {
+  const int num_shards_used =
+      NumShardsUsedByTransformRangeConcurrently(block_size, total);
+  if (num_shards_used == 1) {
+    fn(0, total);
+    return;
+  }
+
+  // Adapted from Eigen's parallelFor implementation.
+  BlockingCounter counter(num_shards_used);
+  std::function<void(int64, int64)> handle_range =
+      [=, &handle_range, &counter, &fn](int64 first, int64 last) {
+        while (last - first > block_size) {
+          // Find something near the midpoint which is a multiple of block size.
+          const int64 mid = first + ((last - first) / 2 + block_size - 1) /
+                                        block_size * block_size;
+          Schedule([=, &handle_range]() { handle_range(mid, last); });
+          last = mid;
+        }
+        // Single block or less, execute directly.
+        fn(first, last);
+        counter.DecrementCount();  // The shard is done.
+      };
+  if (num_shards_used <= NumThreads()) {
+    // Avoid a thread hop by running the root of the tree and one block on the
+    // main thread.
+    handle_range(0, total);
+  } else {
+    // Execute the root in the thread pool to avoid running work on more than
+    // numThreads() threads.
+    Schedule([=, &handle_range]() { handle_range(0, total); });
+  }
+  counter.Wait();
+}
+
 void ThreadPool::ParallelFor(int64 total, int64 cost_per_unit,
                              std::function<void(int64, int64)> fn) {
   impl_->ParallelFor(total, cost_per_unit, std::move(fn));
diff --git a/tensorflow/core/lib/core/threadpool.h b/tensorflow/core/lib/core/threadpool.h
index 74df7c8..e14ad7a 100644
--- a/tensorflow/core/lib/core/threadpool.h
+++ b/tensorflow/core/lib/core/threadpool.h
@@ -59,6 +59,20 @@
   // Schedules fn() for execution in the pool of threads.
   void Schedule(std::function<void()> fn);
 
+  // Requires 0 < block_size <= total.
+  // Spawns k threads and calls fn(i*block_size, (i+1)*block_size) from the
+  // ith thread (i>=0). When (i+1)*block_size > total, fn(i*block_size, total)
+  // is called instead. k = NumShardsUsedByTransformRangeConcurrently(...).
+  // Note that when there aren't enough threads in the pool to achieve full
+  // parallelism, function calls will be automatically queued.
+  void TransformRangeConcurrently(const int64 block_size, const int64 total,
+                                  const std::function<void(int64, int64)>& fn);
+
+  // Returns the number of threads spawned by calling TransformRangeConcurrently
+  // with these parameters.
+  int NumShardsUsedByTransformRangeConcurrently(const int64 block_size,
+                                                const int64 total);
+
   // ParallelFor shards the "total" units of work assuming each unit of work
   // having roughly "cost_per_unit" cost, in cycles. Each unit of work is
   // indexed 0, 1, ..., total - 1. Each shard contains 1 or more units of work
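
A brief usage sketch of the new interface, assuming the ThreadPool API added in this change; the pool size, block size, and totals below are illustrative. The shard count is the ceiling of `total / block_size`, computed as `(total + block_size - 1) / block_size` in the implementation above.

```cpp
// Usage sketch only; not part of the library.
#include <atomic>
#include <iostream>
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/env.h"

int main() {
  tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "demo", 8);

  const tensorflow::int64 block_size = 100, total = 1003;
  // Ceiling division: (1003 + 100 - 1) / 100 = 11 shards.
  std::cout << pool.NumShardsUsedByTransformRangeConcurrently(block_size, total)
            << " shards\n";

  std::atomic<tensorflow::int64> processed{0};
  pool.TransformRangeConcurrently(
      block_size, total,
      // fn receives a half-open range [first, last) of at most block_size units.
      [&processed](tensorflow::int64 first, tensorflow::int64 last) {
        processed += last - first;
      });
  std::cout << processed << " units processed\n";  // prints 1003
}
```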
diff --git a/tensorflow/core/lib/core/threadpool_test.cc b/tensorflow/core/lib/core/threadpool_test.cc
index 320f3eb..db996b7 100644
--- a/tensorflow/core/lib/core/threadpool_test.cc
+++ b/tensorflow/core/lib/core/threadpool_test.cc
@@ -61,6 +61,67 @@
   }
 }
 
+void RunSharding(int64 block_size, int64 total, ThreadPool* threads) {
+  mutex mu;
+  int64 num_shards = 0;
+  int64 num_done_work = 0;
+  std::vector<bool> work(total, false);
+  threads->TransformRangeConcurrently(
+      block_size, total,
+      [=, &mu, &num_shards, &num_done_work, &work](int64 start, int64 end) {
+        VLOG(1) << "Shard [" << start << "," << end << ")";
+        EXPECT_GE(start, 0);
+        EXPECT_LE(end, total);
+        mutex_lock l(mu);
+        ++num_shards;
+        for (; start < end; ++start) {
+          EXPECT_FALSE(work[start]);  // No duplicate
+          ++num_done_work;
+          work[start] = true;
+        }
+      });
+  LOG(INFO) << block_size << " " << total;
+  const int64 num_workers = (total + block_size - 1) / block_size;
+  EXPECT_EQ(num_done_work, total);
+  if (num_workers < threads->NumThreads()) {
+    // If the intention is to limit the parallelism explicitly, we'd
+    // better honor it. Ideally, even if per_thread_max_parallelism >
+    // num_workers, we should expect that the Shard() implementation does
+    // not over-shard. Unfortunately, ThreadPoolDevice::parallelFor
+    // tends to over-shard.
+    EXPECT_LE(num_shards, 1 + num_workers);
+  }
+}
+
+// Adapted from work_sharder_test.cc
+TEST(SparseUtilsTest, TransformRangeConcurrently) {
+  ThreadPool threads(Env::Default(), "test", 16);
+  for (auto block_size : {1, 7, 10, 64, 100, 256, 1000, 9999}) {
+    for (auto diff : {0, 1, 11, 102, 1003, 10005, 1000007}) {
+      const int64 total = block_size + diff;
+      RunSharding(block_size, total, &threads);
+    }
+  }
+}
+
+TEST(SparseUtilsTest, NumShardsUsedByTransformRangeConcurrently) {
+  ThreadPool threads(Env::Default(), "test", 16);
+  EXPECT_EQ(1, threads.NumShardsUsedByTransformRangeConcurrently(
+                   3 /* block_size */, 3 /* total */));
+  EXPECT_EQ(2, threads.NumShardsUsedByTransformRangeConcurrently(
+                   3 /* block_size */, 4 /* total */));
+  EXPECT_EQ(2, threads.NumShardsUsedByTransformRangeConcurrently(
+                   3 /* block_size */, 5 /* total */));
+  EXPECT_EQ(2, threads.NumShardsUsedByTransformRangeConcurrently(
+                   3 /* block_size */, 6 /* total */));
+  EXPECT_EQ(3, threads.NumShardsUsedByTransformRangeConcurrently(
+                   3 /* block_size */, 7 /* total */));
+  EXPECT_EQ(7, threads.NumShardsUsedByTransformRangeConcurrently(
+                   1 /* block_size */, 7 /* total */));
+  EXPECT_EQ(1, threads.NumShardsUsedByTransformRangeConcurrently(
+                   0 /* block_size */, 7 /* total */));
+}
+
 TEST(ThreadPool, ParallelFor) {
   Context outer_context(ContextKind::kThread);
   // Make ParallelFor use as many threads as possible.
diff --git a/tensorflow/core/lib/io/block_builder.h b/tensorflow/core/lib/io/block_builder.h
index e292768..117b6a0 100644
--- a/tensorflow/core/lib/io/block_builder.h
+++ b/tensorflow/core/lib/io/block_builder.h
@@ -20,6 +20,7 @@
 
 #include <stdint.h>
 #include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
 namespace table {
diff --git a/tensorflow/core/lib/io/path.h b/tensorflow/core/lib/io/path.h
index e3649fd..38fb0c5 100644
--- a/tensorflow/core/lib/io/path.h
+++ b/tensorflow/core/lib/io/path.h
@@ -17,6 +17,7 @@
 #define TENSORFLOW_CORE_LIB_IO_PATH_H_
 
 #include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
 namespace io {
diff --git a/tensorflow/core/lib/monitoring/collection_registry.h b/tensorflow/core/lib/monitoring/collection_registry.h
index c204d52..9e4e198 100644
--- a/tensorflow/core/lib/monitoring/collection_registry.h
+++ b/tensorflow/core/lib/monitoring/collection_registry.h
@@ -28,6 +28,7 @@
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
 namespace monitoring {
diff --git a/tensorflow/core/lib/monitoring/metric_def.h b/tensorflow/core/lib/monitoring/metric_def.h
index 756e5c2..bc4365e 100644
--- a/tensorflow/core/lib/monitoring/metric_def.h
+++ b/tensorflow/core/lib/monitoring/metric_def.h
@@ -21,6 +21,7 @@
 
 #include "tensorflow/core/framework/summary.pb.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
 namespace monitoring {
diff --git a/tensorflow/core/lib/png/png_io.h b/tensorflow/core/lib/png/png_io.h
index bb5d20f..c876c51 100644
--- a/tensorflow/core/lib/png/png_io.h
+++ b/tensorflow/core/lib/png/png_io.h
@@ -37,6 +37,7 @@
 
 #include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/platform/png.h"
+#include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
 namespace png {
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 7dbb18a..c249506 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -2916,6 +2916,34 @@
 
 }  // namespace
 
+REGISTER_OP("UpperBound")
+    .Input("sorted_inputs: T")
+    .Input("values: T")
+    .Output("output: out_type")
+    .Attr("T: type")
+    .Attr("out_type: {int32, int64} = DT_INT32")
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle unused_shape;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused_shape));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape));
+      c->set_output(0, c->input(1));
+      return Status::OK();
+    });
+
+REGISTER_OP("LowerBound")
+    .Input("sorted_inputs: T")
+    .Input("values: T")
+    .Output("output: out_type")
+    .Attr("T: type")
+    .Attr("out_type: {int32, int64} = DT_INT32")
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle unused_shape;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused_shape));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape));
+      c->set_output(0, c->input(1));
+      return Status::OK();
+    });
+
 REGISTER_OP("ScatterNd")
     .Input("indices: Tindices")
     .Input("updates: T")
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 57c6bda..e30a111 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -29388,6 +29388,38 @@
   }
 }
 op {
+  name: "LowerBound"
+  input_arg {
+    name: "sorted_inputs"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
   name: "MakeIterator"
   input_arg {
     name: "dataset"
@@ -38880,6 +38912,30 @@
   is_stateful: true
 }
 op {
+  name: "PrintV2"
+  input_arg {
+    name: "input"
+    type: DT_STRING
+  }
+  attr {
+    name: "output_stream"
+    type: "string"
+    default_value {
+      s: "stderr"
+    }
+    allowed_values {
+      list {
+        s: "stdout"
+        s: "stderr"
+        s: "log(info)"
+        s: "log(warning)"
+        s: "log(error)"
+      }
+    }
+  }
+  is_stateful: true
+}
+op {
   name: "PriorityQueue"
   output_arg {
     name: "handle"
@@ -70188,6 +70244,43 @@
   }
 }
 op {
+  name: "StringFormat"
+  input_arg {
+    name: "inputs"
+    type_list_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type: DT_STRING
+  }
+  attr {
+    name: "T"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "template"
+    type: "string"
+    default_value {
+      s: "%s"
+    }
+  }
+  attr {
+    name: "placeholder"
+    type: "string"
+    default_value {
+      s: "%s"
+    }
+  }
+  attr {
+    name: "summarize"
+    type: "int"
+    default_value {
+      i: 3
+    }
+  }
+}
+op {
   name: "StringJoin"
   input_arg {
     name: "inputs"
@@ -75267,6 +75360,38 @@
   is_stateful: true
 }
 op {
+  name: "UpperBound"
+  input_arg {
+    name: "sorted_inputs"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
   name: "VarHandleOp"
   output_arg {
     name: "resource"
@@ -75602,9 +75727,21 @@
     type: DT_VARIANT
   }
   input_arg {
-    name: "window_size"
+    name: "size"
     type: DT_INT64
   }
+  input_arg {
+    name: "shift"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "stride"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "drop_remainder"
+    type: DT_BOOL
+  }
   output_arg {
     name: "handle"
     type: DT_VARIANT
diff --git a/tensorflow/core/ops/cudnn_rnn_ops.cc b/tensorflow/core/ops/cudnn_rnn_ops.cc
index f78f7a8..f84142c 100644
--- a/tensorflow/core/ops/cudnn_rnn_ops.cc
+++ b/tensorflow/core/ops/cudnn_rnn_ops.cc
@@ -37,7 +37,6 @@
 using shape_inference::InferenceContext;
 using shape_inference::ShapeHandle;
 
-
 REGISTER_OP("CudnnRNNParamsSize")
     .Input("num_layers: int32")
     .Input("num_units: int32")
@@ -52,11 +51,16 @@
     .Attr("seed2: int = 0")
     .Output("params_size: S")
     .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle unused;
+      // num_layers, num_units, and input_size should be scalars.
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+
       c->set_output(0, c->Vector(1));
       return Status::OK();
     });
 
-
 REGISTER_OP("CudnnRNN")
     .Input("input: T")
     .Input("input_h: T")
@@ -248,7 +252,6 @@
       return Status::OK();
     });
 
-
 REGISTER_OP("CudnnRNNCanonicalToParams")
     .Input("num_layers: int32")
     .Input("num_units: int32")
diff --git a/tensorflow/core/ops/cudnn_rnn_ops_test.cc b/tensorflow/core/ops/cudnn_rnn_ops_test.cc
index 2dd8675..13c3b93 100644
--- a/tensorflow/core/ops/cudnn_rnn_ops_test.cc
+++ b/tensorflow/core/ops/cudnn_rnn_ops_test.cc
@@ -26,7 +26,16 @@
 
 TEST(CudnnRNNOpsTest, ParamsSize_ShapeFn) {
   ShapeInferenceTestOp op("CudnnRNNParamsSize");
-  INFER_OK(op, "[1];[1];[1]", "[1]");
+  INFER_OK(op, "[];[];[]", "[1]");
+  INFER_OK(op, "?;[];[]", "[1]");
+  INFER_OK(op, "[];?;[]", "[1]");
+  INFER_OK(op, "[];[];?", "[1]");
+  INFER_OK(op, "[];?;?", "[1]");
+  INFER_OK(op, "?;?;?", "[1]");
+
+  INFER_ERROR("Shape must be rank 0 ", op, "[1,2];?;[]");
+  INFER_ERROR("Shape must be rank 0 ", op, "?;[2];[]");
+  INFER_ERROR("Shape must be rank 0 ", op, "?;?;[1]");
 }
 
 TEST(CudnnRNNOpsTest, ForwardLstm_ShapeFn) {
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 7d9e7b2..4d3f272 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -396,14 +396,20 @@
 
 REGISTER_OP("WindowDataset")
     .Input("input_dataset: variant")
-    .Input("window_size: int64")
+    .Input("size: int64")
+    .Input("shift: int64")
+    .Input("stride: int64")
+    .Input("drop_remainder: bool")
     .Output("handle: variant")
     .Attr("output_types: list(type) >= 1")
     .Attr("output_shapes: list(shape) >= 1")
     .SetShapeFn([](shape_inference::InferenceContext* c) {
       shape_inference::ShapeHandle unused;
-      // batch_size should be a scalar.
+      // size, shift, stride, and drop_remainder should be scalars.
       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
       return shape_inference::ScalarShape(c);
     });
 
diff --git a/tensorflow/core/ops/logging_ops.cc b/tensorflow/core/ops/logging_ops.cc
index 639d211..2034d36 100644
--- a/tensorflow/core/ops/logging_ops.cc
+++ b/tensorflow/core/ops/logging_ops.cc
@@ -20,6 +20,8 @@
 
 namespace tensorflow {
 
+using shape_inference::InferenceContext;
+
 REGISTER_OP("Assert")
     .Input("condition: bool")
     .Input("data: T")
@@ -44,6 +46,23 @@
 
 WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("Print");
 
+REGISTER_OP("PrintV2")
+    .Input("input: string")
+    .SetIsStateful()
+    .Attr(
+        "output_stream: {'stdout', 'stderr', 'log(info)', "
+        "'log(warning)', 'log(error)'} = 'stderr'")
+    .SetShapeFn([](InferenceContext* c) {
+      // Make sure that the input is a scalar.
+      if (c->Rank(c->input(0)) != 0) {
+        return errors::InvalidArgument("input must be a scalar, but has rank: ",
+                                       c->Rank(c->input(0)));
+      }
+      return Status::OK();
+    });
+
+WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("PrintV2");
+
 // ----------------------------------------------------------------------------
 // Operators that deal with SummaryProtos (encoded as DT_STRING tensors) as
 // inputs or outputs in various ways.
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 190f6aa..594edfd 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -14536,6 +14536,38 @@
   }
 }
 op {
+  name: "LowerBound"
+  input_arg {
+    name: "sorted_inputs"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
   name: "MakeIterator"
   input_arg {
     name: "dataset"
@@ -19521,6 +19553,30 @@
   is_stateful: true
 }
 op {
+  name: "PrintV2"
+  input_arg {
+    name: "input"
+    type: DT_STRING
+  }
+  attr {
+    name: "output_stream"
+    type: "string"
+    default_value {
+      s: "stderr"
+    }
+    allowed_values {
+      list {
+        s: "stdout"
+        s: "stderr"
+        s: "log(info)"
+        s: "log(warning)"
+        s: "log(error)"
+      }
+    }
+  }
+  is_stateful: true
+}
+op {
   name: "PriorityQueue"
   output_arg {
     name: "handle"
@@ -32735,6 +32791,43 @@
   }
 }
 op {
+  name: "StringFormat"
+  input_arg {
+    name: "inputs"
+    type_list_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type: DT_STRING
+  }
+  attr {
+    name: "T"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "template"
+    type: "string"
+    default_value {
+      s: "%s"
+    }
+  }
+  attr {
+    name: "placeholder"
+    type: "string"
+    default_value {
+      s: "%s"
+    }
+  }
+  attr {
+    name: "summarize"
+    type: "int"
+    default_value {
+      i: 3
+    }
+  }
+}
+op {
   name: "StringJoin"
   input_arg {
     name: "inputs"
@@ -35954,6 +36047,38 @@
   is_stateful: true
 }
 op {
+  name: "UpperBound"
+  input_arg {
+    name: "sorted_inputs"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "values"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "out_type"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "out_type"
+    type: "type"
+    default_value {
+      type: DT_INT32
+    }
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+}
+op {
   name: "VarHandleOp"
   output_arg {
     name: "resource"
@@ -36199,9 +36324,21 @@
     type: DT_VARIANT
   }
   input_arg {
-    name: "window_size"
+    name: "size"
     type: DT_INT64
   }
+  input_arg {
+    name: "shift"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "stride"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "drop_remainder"
+    type: DT_BOOL
+  }
   output_arg {
     name: "handle"
     type: DT_VARIANT
diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc
index ef8b15d..9915983 100644
--- a/tensorflow/core/ops/string_ops.cc
+++ b/tensorflow/core/ops/string_ops.cc
@@ -13,6 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
+#include "absl/strings/str_split.h"
 #include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/shape_inference.h"
@@ -102,6 +103,32 @@
     .Attr("fill: string = ''")
     .SetShapeFn(shape_inference::UnchangedShape);
 
+REGISTER_OP("StringFormat")
+    .Input("inputs: T")
+    .Output("output: string")
+    .Attr("T: list(type) >= 0")
+    .Attr("template: string = '%s'")
+    .Attr("placeholder: string = '%s'")
+    .Attr("summarize: int = 3")
+    .SetShapeFn([](InferenceContext* c) {
+      string template_;
+      string placeholder;
+      TF_RETURN_IF_ERROR(c->GetAttr("template", &template_));
+      TF_RETURN_IF_ERROR(c->GetAttr("placeholder", &placeholder));
+
+      std::vector<std::string> split_template;
+      split_template = absl::StrSplit(template_, placeholder);
+      int64 num_placeholders = split_template.size() - 1;
+      if (c->num_inputs() != num_placeholders) {
+        return errors::InvalidArgument(strings::StrCat(
+            "num placeholders in template and num inputs must match: ",
+            num_placeholders, " vs. ", c->num_inputs()));
+      }
+
+      c->set_output(0, c->Scalar());
+      return Status::OK();
+    });
+
 REGISTER_OP("StringJoin")
     .Input("inputs: N * string")
     .Attr("N: int")
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index 83228fa..83ea853 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -25,6 +25,7 @@
 #ifdef _WIN32
 #include <io.h>  // for _mktemp
 #endif
+#include "absl/base/macros.h"
 #include "include/json/json.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
@@ -63,7 +64,7 @@
 // The HTTP response code "308 Resume Incomplete".
 constexpr uint64 HTTP_CODE_RESUME_INCOMPLETE = 308;
 // The environment variable that overrides the size of the readahead buffer.
-// DEPRECATED. Use GCS_BLOCK_SIZE_MB instead.
+ABSL_DEPRECATED("Use GCS_BLOCK_SIZE_MB instead.")
 constexpr char kReadaheadBufferSize[] = "GCS_READAHEAD_BUFFER_SIZE_BYTES";
 // The environment variable that disables the GCS block cache for reads.
 // This is the explicit alternative to setting BLOCK_SIZE or MAX_SIZE to 0, and
diff --git a/tensorflow/core/platform/default/build_config_root.bzl b/tensorflow/core/platform/default/build_config_root.bzl
index 3a012c2..37475fe 100644
--- a/tensorflow/core/platform/default/build_config_root.bzl
+++ b/tensorflow/core/platform/default/build_config_root.bzl
@@ -3,64 +3,64 @@
 # be separate to avoid cyclic references.
 
 def tf_cuda_tests_tags():
-  return ["requires-gpu"]
+    return ["requires-gpu", "local", "gpu"]
 
 def tf_sycl_tests_tags():
-  return ["requires-gpu"]
+    return ["requires-gpu", "local", "gpu"]
 
 def tf_additional_plugin_deps():
-  return select({
-      str(Label("//tensorflow:with_xla_support")): [
-          str(Label("//tensorflow/compiler/jit"))
-      ],
-      "//conditions:default": [],
-  })
+    return select({
+        str(Label("//tensorflow:with_xla_support")): [
+            str(Label("//tensorflow/compiler/jit")),
+        ],
+        "//conditions:default": [],
+    })
 
 def tf_additional_xla_deps_py():
-  return []
+    return []
 
 def tf_additional_grpc_deps_py():
-  return []
+    return []
 
 def tf_additional_license_deps():
-  return select({
-      str(Label("//tensorflow:with_xla_support")): ["@llvm//:LICENSE.TXT"],
-      "//conditions:default": [],
-  })
+    return select({
+        str(Label("//tensorflow:with_xla_support")): ["@llvm//:LICENSE.TXT"],
+        "//conditions:default": [],
+    })
 
 def tf_additional_verbs_deps():
-  return select({
-      str(Label("//tensorflow:with_verbs_support")): [
-          str(Label("//tensorflow/contrib/verbs:verbs_server_lib")),
-          str(Label("//tensorflow/contrib/verbs:grpc_verbs_client")),
-      ],
-      "//conditions:default": [],
-  })
+    return select({
+        str(Label("//tensorflow:with_verbs_support")): [
+            str(Label("//tensorflow/contrib/verbs:verbs_server_lib")),
+            str(Label("//tensorflow/contrib/verbs:grpc_verbs_client")),
+        ],
+        "//conditions:default": [],
+    })
 
 def tf_additional_mpi_deps():
-  return select({
-      str(Label("//tensorflow:with_mpi_support")): [
-          str(Label("//tensorflow/contrib/mpi:mpi_server_lib")),
-      ],
-      "//conditions:default": [],
-  })
+    return select({
+        str(Label("//tensorflow:with_mpi_support")): [
+            str(Label("//tensorflow/contrib/mpi:mpi_server_lib")),
+        ],
+        "//conditions:default": [],
+    })
 
 def tf_additional_gdr_deps():
-  return select({
-      str(Label("//tensorflow:with_gdr_support")): [
-          str(Label("//tensorflow/contrib/gdr:gdr_server_lib")),
-      ],
-      "//conditions:default": [],
-  })
+    return select({
+        str(Label("//tensorflow:with_gdr_support")): [
+            str(Label("//tensorflow/contrib/gdr:gdr_server_lib")),
+        ],
+        "//conditions:default": [],
+    })
 
-def if_static(extra_deps, otherwise=[]):
-  return select({
-      str(Label("//tensorflow:framework_shared_object")): otherwise,
-      "//conditions:default": extra_deps,
-  })
+def if_static(extra_deps, otherwise = []):
+    return select({
+        str(Label("//tensorflow:framework_shared_object")): otherwise,
+        "//conditions:default": extra_deps,
+    })
 
-def if_dynamic_kernels(extra_deps, otherwise=[]):
-  return select({
-      str(Label("//tensorflow:dynamic_loaded_kernels")): extra_deps,
-      "//conditions:default": otherwise,
-  })
+def if_dynamic_kernels(extra_deps, otherwise = []):
+    return select({
+        str(Label("//tensorflow:dynamic_loaded_kernels")): extra_deps,
+        "//conditions:default": otherwise,
+    })
diff --git a/tensorflow/core/platform/default/cord.h b/tensorflow/core/platform/default/cord.h
index 1ab6821..5823374 100644
--- a/tensorflow/core/platform/default/cord.h
+++ b/tensorflow/core/platform/default/cord.h
@@ -16,9 +16,6 @@
 #ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_CORD_H_
 #define TENSORFLOW_CORE_PLATFORM_DEFAULT_CORD_H_
 
-class Cord;
-namespace absl {
-using ::Cord;
-}  // namespace absl
+// TODO(ebrevdo): Fill this in.
 
 #endif  // TENSORFLOW_CORE_PLATFORM_DEFAULT_CORD_H_
diff --git a/tensorflow/core/platform/file_system.h b/tensorflow/core/platform/file_system.h
index 30059dc..156af6c 100644
--- a/tensorflow/core/platform/file_system.h
+++ b/tensorflow/core/platform/file_system.h
@@ -255,10 +255,13 @@
   /// \brief Append 'data' to the file.
   virtual Status Append(StringPiece data) = 0;
 
+  // TODO(ebrevdo): Remove this ifdef when absl is updated.
+#if defined(PLATFORM_GOOGLE)
   // \brief Append 'data' to the file.
   virtual Status Append(const absl::Cord& cord) {
     return errors::Unimplemented("Append(absl::Cord) is not implemented");
   }
+#endif
 
   /// \brief Close the file.
   ///
diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto
index 625d564..85cd023 100644
--- a/tensorflow/core/protobuf/config.proto
+++ b/tensorflow/core/protobuf/config.proto
@@ -68,7 +68,7 @@
   //    after the process starts.  Users are required to use vendor
   //    specific mechanisms (e.g., CUDA_VISIBLE_DEVICES) to control the
   //    physical to visible device mapping prior to invoking TensorFlow.
-  // 2. In the code, the ids in this list are also called "CUDA GPU id"s,
+  // 2. In the code, the ids in this list are also called "platform GPU id"s,
   //    and the 'virtual' ids of GPU devices (i.e. the ids in the device
   //    name "/device:GPU:<id>") are also called "TF GPU id"s. Please
   //    refer to third_party/tensorflow/core/common_runtime/gpu/gpu_id.h
diff --git a/tensorflow/core/protobuf/replay_log.proto b/tensorflow/core/protobuf/replay_log.proto
new file mode 100644
index 0000000..7644314
--- /dev/null
+++ b/tensorflow/core/protobuf/replay_log.proto
@@ -0,0 +1,47 @@
+syntax = "proto3";
+
+option cc_enable_arenas = true;
+package tensorflow;
+
+import "tensorflow/core/framework/graph.proto";
+import "tensorflow/core/protobuf/cluster.proto";
+import "tensorflow/core/protobuf/master.proto";
+
+// Records the creation of a new replay session.  We record the device listing
+// here to capture the state of the cluster.
+message NewReplaySession {
+  ListDevicesResponse devices = 1;
+  string session_handle = 2;
+}
+
+message ReplayOp {
+  double start_time_us = 31;
+  double end_time_us = 32;
+
+  oneof op {
+    CreateSessionRequest create_session = 1;
+    ExtendSessionRequest extend_session = 2;
+    PartialRunSetupRequest partial_run_setup = 3;
+    RunStepRequest run_step = 4;
+    CloseSessionRequest close_session = 5;
+    ListDevicesRequest list_devices = 6;
+    ResetRequest reset_request = 7;
+    MakeCallableRequest make_callable = 8;
+    RunCallableRequest run_callable = 9;
+    ReleaseCallableRequest release_callable = 10;
+    NewReplaySession new_replay_session = 11;
+  }
+
+  oneof response {
+    CreateSessionResponse create_session_response = 21;
+    ExtendSessionResponse extend_session_response = 22;
+    PartialRunSetupResponse partial_run_setup_response = 23;
+    RunStepResponse run_step_response = 24;
+    CloseSessionResponse close_session_response = 25;
+    ListDevicesResponse list_devices_response = 26;
+    ResetResponse reset_request_response = 27;
+    MakeCallableResponse make_callable_response = 28;
+    RunCallableResponse run_callable_response = 29;
+    ReleaseCallableResponse release_callable_response = 30;
+  }
+}
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index 4129c93..b043a69 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -19,12 +19,12 @@
 // TensorFlow uses semantic versioning, see http://semver.org/.
 
 #define TF_MAJOR_VERSION 1
-#define TF_MINOR_VERSION 10
+#define TF_MINOR_VERSION 11
 #define TF_PATCH_VERSION 0
 
 // TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1",
 // "-beta", "-rc", "-rc.1")
-#define TF_VERSION_SUFFIX ""
+#define TF_VERSION_SUFFIX "-rc1"
 
 #define TF_STR_HELPER(x) #x
 #define TF_STR(x) TF_STR_HELPER(x)
diff --git a/tensorflow/core/util/cuda_kernel_helper.h b/tensorflow/core/util/cuda_kernel_helper.h
index 540adb5..f6f0408 100644
--- a/tensorflow/core/util/cuda_kernel_helper.h
+++ b/tensorflow/core/util/cuda_kernel_helper.h
@@ -93,11 +93,11 @@
 }
 
 namespace cuda_helper {
-template <typename IntType>
-__device__ IntType upper_bound(IntType* first, IntType count, IntType val) {
-  IntType* orig = first;
-  IntType* it = nullptr;
-  IntType step = 0;
+template <typename T, typename OutType = int32>
+__device__ OutType upper_bound(const T* first, OutType count, T val) {
+  const T* orig = first;
+  const T* it = nullptr;
+  OutType step = 0;
   while (count > 0) {
     it = first;
     step = count / 2;
@@ -112,6 +112,27 @@
 
   return first - orig;
 }
+
+template <typename T, typename OutType = int32>
+__device__ OutType lower_bound(const T* first, OutType count, T val) {
+  const T* orig = first;
+  const T* it = nullptr;
+  OutType step = 0;
+  while (count > 0) {
+    it = first;
+    step = count / 2;
+    it += step;
+    if (*it < val) {
+      first = ++it;
+      count -= step + 1;
+    } else {
+      count = step;
+    }
+  }
+
+  return first - orig;
+}
+
 }  // namespace cuda_helper
 }  // namespace tensorflow
 
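The device helpers above implement the classic halving binary search: `lower_bound` returns the first index whose element is not less than `val`, and `upper_bound` the first index whose element is greater than `val`, matching the offsets produced by `std::lower_bound` / `std::upper_bound`. A host-side sketch of the same loop, checked against the standard library (illustrative, not CUDA code):

```cpp
// Host-side sketch of the lower_bound loop above, checked against std::lower_bound.
#include <algorithm>
#include <cassert>
#include <cstdint>

template <typename T, typename OutType = int32_t>
OutType LowerBoundHost(const T* first, OutType count, T val) {
  const T* orig = first;
  while (count > 0) {
    const T* it = first;
    OutType step = count / 2;
    it += step;
    if (*it < val) {  // upper_bound uses !(val < *it) as the condition instead.
      first = ++it;
      count -= step + 1;
    } else {
      count = step;
    }
  }
  return static_cast<OutType>(first - orig);
}

int main() {
  const float data[] = {1.f, 3.f, 3.f, 5.f, 9.f};
  for (float v : {0.f, 3.f, 4.f, 9.f, 10.f}) {
    int32_t got = LowerBoundHost(data, int32_t{5}, v);
    assert(got == std::lower_bound(data, data + 5, v) - data);
  }
  return 0;
}
```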
diff --git a/tensorflow/core/util/sparse/sparse_tensor.h b/tensorflow/core/util/sparse/sparse_tensor.h
index 0f04b65..b9ca8ab 100644
--- a/tensorflow/core/util/sparse/sparse_tensor.h
+++ b/tensorflow/core/util/sparse/sparse_tensor.h
@@ -20,6 +20,7 @@
 #include <numeric>
 #include <vector>
 
+#include "absl/base/macros.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_types.h"
@@ -95,21 +96,21 @@
 
   SparseTensor() : dims_(0) {}
 
-  // DEPRECATED: use Create() functions instead of constructors directly.
+  ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
   SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape)
       : SparseTensor(ix, vals, TensorShapeToVector(shape),
                      UndefinedOrder(TensorShapeToVector(shape))) {}
 
-  // DEPRECATED: use Create() functions instead of constructors directly.
+  ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
   SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape)
       : SparseTensor(ix, vals, shape, UndefinedOrder(shape)) {}
 
-  // DEPRECATED: use Create() functions instead of constructors directly.
+  ABSL_DEPRECATED("use Create() functions instead of constructors directly.")
   SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape,
                const VarDimArray order)
       : SparseTensor(ix, vals, TensorShapeToVector(shape), order) {}
 
-  // DEPRECATED: use Create() functions instead of constructors directly.
+  ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
   SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape,
                const VarDimArray order)
       : ix_(ix),
@@ -237,9 +238,10 @@
   static Status Split(const SparseTensor& tensor, const int split_dim,
                       const int num_split, std::vector<SparseTensor>* result);
 
-  // DEPRECATED: use the form of Split() that takes an output pointer and
-  // returns a status instead.
   template <typename T>
+  ABSL_DEPRECATED(
+      "Use the form of Split() that takes an output pointer and returns a "
+      "status instead.")
   static std::vector<SparseTensor> Split(const SparseTensor& tensor,
                                          const int split_dim,
                                          const int num_split,
diff --git a/tensorflow/core/util/tensor_bundle/naming.h b/tensorflow/core/util/tensor_bundle/naming.h
index 6539d56..7b10197 100644
--- a/tensorflow/core/util/tensor_bundle/naming.h
+++ b/tensorflow/core/util/tensor_bundle/naming.h
@@ -35,6 +35,7 @@
 #define TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_NAMING_H_
 
 #include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/core/util/work_sharder.cc b/tensorflow/core/util/work_sharder.cc
index f4bd295..74f0713 100644
--- a/tensorflow/core/util/work_sharder.cc
+++ b/tensorflow/core/util/work_sharder.cc
@@ -50,6 +50,8 @@
               max_parallelism);
 }
 
+// DEPRECATED: Prefer threadpool->TransformRangeConcurrently, which allows you
+// to directly specify the shard size.
 void Sharder::Do(int64 total, int64 cost_per_unit, const Work& work,
                  const Runner& runner, int max_parallelism) {
   cost_per_unit = std::max(int64{1}, cost_per_unit);
diff --git a/tensorflow/core/util/work_sharder.h b/tensorflow/core/util/work_sharder.h
index b12c31c..9db85a5 100644
--- a/tensorflow/core/util/work_sharder.h
+++ b/tensorflow/core/util/work_sharder.h
@@ -23,6 +23,9 @@
 
 namespace tensorflow {
 
+// DEPRECATED: Prefer threadpool->TransformRangeConcurrently, which allows you
+// to directly specify the shard size. Use this function only if you want to
+// manually cap parallelism.
 // Shards the "total" unit of work assuming each unit of work having
 // roughly "cost_per_unit". Each unit of work is indexed 0, 1, ...,
 // total - 1. Each shard contains 1 or more units of work and the
diff --git a/tensorflow/examples/tutorials/mnist/BUILD b/tensorflow/examples/tutorials/mnist/BUILD
index d4070fd..99da44d 100644
--- a/tensorflow/examples/tutorials/mnist/BUILD
+++ b/tensorflow/examples/tutorials/mnist/BUILD
@@ -84,6 +84,18 @@
 )
 
 py_binary(
+    name = "mnist_softmax_xla",
+    srcs = [
+        "mnist_softmax_xla.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":input_data",
+        "//tensorflow:tensorflow_py",
+    ],
+)
+
+py_binary(
     name = "mnist_deep",
     srcs = [
         "mnist_deep.py",
diff --git a/tensorflow/go/README.md b/tensorflow/go/README.md
index 288a325..3989f9b 100644
--- a/tensorflow/go/README.md
+++ b/tensorflow/go/README.md
@@ -10,7 +10,7 @@
 
 ## Quickstart
 
-Refer to [Installing TensorFlow for Go](https://www.tensorflow.org/install/install_go)
+Refer to [Installing TensorFlow for Go](https://www.tensorflow.org/install/lang_go)
 
 ## Building the TensorFlow C library from source
 
@@ -23,9 +23,7 @@
 
 -   [bazel](https://www.bazel.build/versions/master/docs/install.html)
 -   Environment to build TensorFlow from source code
-    ([Linux](https://www.tensorflow.org/install/install_sources#PrepareLinux)
-    or [OS
-    X](https://www.tensorflow.org/install/install_sources#PrepareMac)).
+    ([Linux or macOS](https://www.tensorflow.org/install/source)).
     If you don't need GPU support, then try the following:
 
     ```sh
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 322b35d..1d72bcd 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -332,7 +332,7 @@
 // Creates a new tensor by applying sparse `updates` to individual values or
 // slices within a tensor (initially zero for numeric, empty for string) of
 // the given `shape` according to indices.  This operator is the inverse of the
-// @{tf.gather_nd} operator which extracts values or slices from a given tensor.
+// `tf.gather_nd` operator which extracts values or slices from a given tensor.
 //
 // If `indices` contains duplicates, then their updates are accumulated (summed).
 //
@@ -1473,7 +1473,7 @@
 //
 // value: a bitmask where a bit i being 1 means to ignore the begin
 // value and instead use the largest interval possible. At runtime
-// begin[i] will be replaced with `[0, n-1) if `stride[i] > 0` or
+// begin[i] will be replaced with `[0, n-1)` if `stride[i] > 0` or
 // `[-1, n-1]` if `stride[i] < 0`
 // If not specified, defaults to 0
 func StridedSliceBeginMask(value int64) StridedSliceAttr {
@@ -1856,6 +1856,32 @@
 	return op.Output(0)
 }
 
+// Ensures that the tensor's shape matches the expected shape.
+//
+// Raises an error if the input tensor's shape does not match the specified shape.
+// Returns the input tensor otherwise.
+//
+// Arguments:
+//	input: A tensor, whose shape is to be validated.
+//	shape: The expected (possibly partially specified) shape of the input tensor.
+//
+// Returns A tensor with the same shape and contents as the input tensor or value.
+func EnsureShape(scope *Scope, input tf.Output, shape tf.Shape) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"shape": shape}
+	opspec := tf.OpSpec{
+		Type: "EnsureShape",
+		Input: []tf.Input{
+			input,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
 // UniqueWithCountsV2Attr is an optional argument to UniqueWithCountsV2.
 type UniqueWithCountsV2Attr func(optionalAttr)
 
@@ -2259,7 +2285,7 @@
 //
 //     output[\\(i_0, ..., i_{K-2}\\)] = params[indices[\\(i_0, ..., i_{K-2}\\)]]
 //
-// Whereas in @{tf.gather} `indices` defines slices into the first
+// Whereas in `tf.gather` `indices` defines slices into the first
 // dimension of `params`, in `tf.gather_nd`, `indices` defines slices into the
 // first `N` dimensions of `params`, where `N = indices.shape[-1]`.
 //
@@ -2356,6 +2382,8 @@
 //     output = [['b0', 'b1'], ['d0', 'c1']]
 // ```
 //
+// See also `tf.gather` and `tf.batch_gather`.
+//
 // Arguments:
 //	params: The tensor from which to gather values.
 //	indices: Index tensor.
@@ -2445,6 +2473,16 @@
 //                      [9, 9, 9]]
 // ```
 //
+// `tf.fill` differs from `tf.constant` in a few ways:
+//
+// *   `tf.fill` only supports scalar contents, whereas `tf.constant` supports
+//     Tensor values.
+// *   `tf.fill` creates an Op in the computation graph that constructs the actual
+//     Tensor value at runtime. This is in contrast to `tf.constant` which embeds
+//     the entire Tensor into the graph with a `Const` node.
+// *   Because `tf.fill` evaluates at graph runtime, it supports dynamic shapes
+//     based on other runtime Tensors, unlike `tf.constant`.
+//
 // Arguments:
 //	dims: 1-D. Represents the shape of the output tensor.
 //	value: 0-D (scalar). Value to fill the returned tensor.
@@ -2858,6 +2896,25 @@
 	return op.Output(0)
 }
 
+// Returns a constant tensor on the host. Only for writing C++ tests.
+//
+// Arguments:
+//	value: Attr `value` is the tensor to return.
+//
+func HostConst(scope *Scope, value tf.Tensor, dtype tf.DataType) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"value": value, "dtype": dtype}
+	opspec := tf.OpSpec{
+		Type: "HostConst",
+
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
 // Splits a tensor into `num_split` tensors along one dimension.
 //
 // Arguments:
@@ -3377,6 +3434,204 @@
 	return op.Output(0)
 }
 
+// Bucketize each feature based on bucket boundaries.
+//
+// An op that returns a list of float tensors, where each tensor represents the
+// bucketized values for a single feature.
+//
+// Arguments:
+//	float_values: float; List of Rank 2 Tensors, each containing float values for a single feature.
+//	bucket_boundaries: float; List of Rank 1 Tensors each containing the bucket boundaries for a single
+// feature.
+//
+// Returns int; List of Rank 2 Tensors each containing the bucketized values for a single feature.
+func BoostedTreesBucketize(scope *Scope, float_values []tf.Output, bucket_boundaries []tf.Output) (buckets []tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "BoostedTreesBucketize",
+		Input: []tf.Input{
+			tf.OutputList(float_values), tf.OutputList(bucket_boundaries),
+		},
+	}
+	op := scope.AddOperation(opspec)
+	if scope.Err() != nil {
+		return
+	}
+	var idx int
+	var err error
+	if buckets, idx, err = makeOutputList(op, idx, "buckets"); err != nil {
+		scope.UpdateErr("BoostedTreesBucketize", err)
+		return
+	}
+	return buckets
+}
+
+// BoostedTreesQuantileStreamResourceFlushAttr is an optional argument to BoostedTreesQuantileStreamResourceFlush.
+type BoostedTreesQuantileStreamResourceFlushAttr func(optionalAttr)
+
+// BoostedTreesQuantileStreamResourceFlushGenerateQuantiles sets the optional generate_quantiles attribute to value.
+//
+// value: bool; If True, the output will be the num_quantiles for each stream, where the ith
+// entry is the ith quantile of the input with an approximation error of epsilon.
+// Duplicate values may be present.
+// If False, the output will be the points of the collected histogram, which roughly
+// translate to 1/epsilon boundaries, without any duplicates.
+// Defaults to False.
+// If not specified, defaults to false
+func BoostedTreesQuantileStreamResourceFlushGenerateQuantiles(value bool) BoostedTreesQuantileStreamResourceFlushAttr {
+	return func(m optionalAttr) {
+		m["generate_quantiles"] = value
+	}
+}
+
+// Flush the summaries for a quantile stream resource.
+//
+// An op that flushes the summaries for a quantile stream resource.
+//
+// Arguments:
+//	quantile_stream_resource_handle: resource handle referring to a QuantileStreamResource.
+//	num_buckets: int; approximate number of buckets unless using generate_quantiles.
+//
+// Returns the created operation.
+func BoostedTreesQuantileStreamResourceFlush(scope *Scope, quantile_stream_resource_handle tf.Output, num_buckets tf.Output, optional ...BoostedTreesQuantileStreamResourceFlushAttr) (o *tf.Operation) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "BoostedTreesQuantileStreamResourceFlush",
+		Input: []tf.Input{
+			quantile_stream_resource_handle, num_buckets,
+		},
+		Attrs: attrs,
+	}
+	return scope.AddOperation(opspec)
+}
+
+// Add the quantile summaries to each quantile stream resource.
+//
+// An op that adds a list of quantile summaries to a quantile stream resource. Each
+// summary Tensor is rank 2, containing summaries (value, weight, min_rank, max_rank)
+// for a single feature.
+//
+// Arguments:
+//	quantile_stream_resource_handle: resource handle referring to a QuantileStreamResource.
+//	summaries: string; List of Rank 2 Tensors, each containing the summaries for a single feature.
+//
+// Returns the created operation.
+func BoostedTreesQuantileStreamResourceAddSummaries(scope *Scope, quantile_stream_resource_handle tf.Output, summaries []tf.Output) (o *tf.Operation) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "BoostedTreesQuantileStreamResourceAddSummaries",
+		Input: []tf.Input{
+			quantile_stream_resource_handle, tf.OutputList(summaries),
+		},
+	}
+	return scope.AddOperation(opspec)
+}
+
+// Makes the summary of quantiles for the batch.
+//
+// An op that takes a list of tensors and outputs the quantile summaries for each tensor.
+//
+// Arguments:
+//	float_values: float; List of Rank 2 Tensors each containing values for a single feature.
+//	example_weights: float; Rank 1 Tensor with weights per instance.
+//	epsilon: float; The required maximum approximation error.
+//
+// Returns float; List of Rank 2 Tensors each containing the quantile summary (value, weight,
+// min_rank, max_rank) of a single feature.
+func BoostedTreesMakeQuantileSummaries(scope *Scope, float_values []tf.Output, example_weights tf.Output, epsilon tf.Output) (summaries []tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "BoostedTreesMakeQuantileSummaries",
+		Input: []tf.Input{
+			tf.OutputList(float_values), example_weights, epsilon,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	if scope.Err() != nil {
+		return
+	}
+	var idx int
+	var err error
+	if summaries, idx, err = makeOutputList(op, idx, "summaries"); err != nil {
+		scope.UpdateErr("BoostedTreesMakeQuantileSummaries", err)
+		return
+	}
+	return summaries
+}
+
+// BoostedTreesCreateQuantileStreamResourceAttr is an optional argument to BoostedTreesCreateQuantileStreamResource.
+type BoostedTreesCreateQuantileStreamResourceAttr func(optionalAttr)
+
+// BoostedTreesCreateQuantileStreamResourceMaxElements sets the optional max_elements attribute to value.
+//
+// value: int; The maximum number of data points that can be fed to the stream.
+// If not specified, defaults to 1099511627776
+func BoostedTreesCreateQuantileStreamResourceMaxElements(value int64) BoostedTreesCreateQuantileStreamResourceAttr {
+	return func(m optionalAttr) {
+		m["max_elements"] = value
+	}
+}
+
+// Create the Resource for Quantile Streams.
+//
+// Arguments:
+//	quantile_stream_resource_handle: resource; Handle to quantile stream resource.
+//	epsilon: float; The required approximation error of the stream resource.
+//	num_streams: int; The number of streams managed by the resource that shares the same epsilon.
+//
+// Returns the created operation.
+func BoostedTreesCreateQuantileStreamResource(scope *Scope, quantile_stream_resource_handle tf.Output, epsilon tf.Output, num_streams tf.Output, optional ...BoostedTreesCreateQuantileStreamResourceAttr) (o *tf.Operation) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "BoostedTreesCreateQuantileStreamResource",
+		Input: []tf.Input{
+			quantile_stream_resource_handle, epsilon, num_streams,
+		},
+		Attrs: attrs,
+	}
+	return scope.AddOperation(opspec)
+}
+
+// Checks whether a quantile stream has been initialized.
+//
+// An op that checks whether the quantile stream resource is initialized.
+//
+// Arguments:
+//	quantile_stream_resource_handle: resource; The reference to quantile stream resource handle.
+//
+// Returns bool; True if the resource is initialized, False otherwise.
+func IsBoostedTreesQuantileStreamResourceInitialized(scope *Scope, quantile_stream_resource_handle tf.Output) (is_initialized tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "IsBoostedTreesQuantileStreamResourceInitialized",
+		Input: []tf.Input{
+			quantile_stream_resource_handle,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
 // Calculates the prior from the training data (the bias) and fills in the first node with the logits' prior. Returns a boolean indicating whether to continue centering.
 //
 // Arguments:
@@ -3486,97 +3741,28 @@
 	return op.Output(0)
 }
 
-// Computes the sum along sparse segments of a tensor.
+// Makes the summary of accumulated stats for the batch.
 //
-// Like `SparseSegmentSum`, but allows missing ids in `segment_ids`. If an id is
-// misisng, the `output` tensor at that position will be zeroed.
-//
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
-//
-// For example:
-//
-// ```python
-// c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
-//
-// tf.sparse_segment_sum_with_num_segments(
-//     c, tf.constant([0, 1]), tf.constant([0, 0]), num_segments=3)
-// # => [[0 0 0 0]
-// #     [0 0 0 0]
-// #     [0 0 0 0]]
-//
-// tf.sparse_segment_sum_with_num_segments(c,
-//                                         tf.constant([0, 1]),
-//                                         tf.constant([0, 2],
-//                                         num_segments=4))
-// # => [[ 1  2  3  4]
-// #     [ 0  0  0  0]
-// #     [-1 -2 -3 -4]
-// #     [ 0  0  0  0]]
-// ```
+// The summary stats contains gradients and hessians accumulated into the corresponding node and bucket for each example.
 //
 // Arguments:
+//	node_ids: int32 Rank 1 Tensor containing node ids, which each example falls into for the requested layer.
+//	gradients: float32; Rank 2 Tensor (shape=[#examples, 1]) for gradients.
+//	hessians: float32; Rank 2 Tensor (shape=[#examples, 1]) for hessians.
+//	bucketized_features_list: int32 list of Rank 1 Tensors, each containing the bucketized feature (for each feature column).
+//	max_splits: int; the maximum number of splits possible in the whole tree.
+//	num_buckets: int; equals to the maximum possible value of bucketized feature.
 //
-//	indices: A 1-D tensor. Has same rank as `segment_ids`.
-//	segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
-//	num_segments: Should equal the number of distinct segment IDs.
-//
-// Returns Has same shape as data, except for dimension 0 which
-// has size `num_segments`.
-func SparseSegmentSumWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
+// Returns output Rank 4 Tensor (shape=[#features, #splits, #buckets, 2]) containing accumulated stats put into the corresponding node and bucket. The first index of 4th dimension refers to gradients, and the second to hessians.
+func BoostedTreesMakeStatsSummary(scope *Scope, node_ids tf.Output, gradients tf.Output, hessians tf.Output, bucketized_features_list []tf.Output, max_splits int64, num_buckets int64) (stats_summary tf.Output) {
 	if scope.Err() != nil {
 		return
 	}
+	attrs := map[string]interface{}{"max_splits": max_splits, "num_buckets": num_buckets}
 	opspec := tf.OpSpec{
-		Type: "SparseSegmentSumWithNumSegments",
+		Type: "BoostedTreesMakeStatsSummary",
 		Input: []tf.Input{
-			data, indices, segment_ids, num_segments,
-		},
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
-// PreventGradientAttr is an optional argument to PreventGradient.
-type PreventGradientAttr func(optionalAttr)
-
-// PreventGradientMessage sets the optional message attribute to value.
-//
-// value: Will be printed in the error when anyone tries to differentiate
-// this operation.
-// If not specified, defaults to ""
-func PreventGradientMessage(value string) PreventGradientAttr {
-	return func(m optionalAttr) {
-		m["message"] = value
-	}
-}
-
-// An identity op that triggers an error if a gradient is requested.
-//
-// When executed in a graph, this op outputs its input tensor as-is.
-//
-// When building ops to compute gradients, the TensorFlow gradient system
-// will return an error when trying to lookup the gradient of this op,
-// because no gradient must ever be registered for this function.  This
-// op exists to prevent subtle bugs from silently returning unimplemented
-// gradients in some corner cases.
-//
-// Arguments:
-//	input: any tensor.
-//
-// Returns the same input tensor.
-func PreventGradient(scope *Scope, input tf.Output, optional ...PreventGradientAttr) (output tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{}
-	for _, a := range optional {
-		a(attrs)
-	}
-	opspec := tf.OpSpec{
-		Type: "PreventGradient",
-		Input: []tf.Input{
-			input,
+			node_ids, gradients, hessians, tf.OutputList(bucketized_features_list),
 		},
 		Attrs: attrs,
 	}
@@ -3584,25 +3770,11 @@
 	return op.Output(0)
 }
 
-// Computes asin of x element-wise.
-func Asin(scope *Scope, x tf.Output) (y tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	opspec := tf.OpSpec{
-		Type: "Asin",
-		Input: []tf.Input{
-			x,
-		},
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
 // Computes the sum along sparse segments of a tensor.
 //
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
 //
 // Like `SegmentSum`, but `segment_ids` can have rank less than `data`'s first
 // dimension, selecting a subset of dimension 0, specified by `indices`.
@@ -3668,28 +3840,32 @@
 
 // Computes the minimum along segments of a tensor.
 //
-// Read @{$math_ops#segmentation$the section on segmentation} for an explanation of
-// segments.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#segmentation)
+// for an explanation of segments.
 //
 // This operator is similar to the unsorted segment sum operator found
 // [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
 // Instead of computing the sum over segments, it computes the minimum such that:
 //
-// \\(output_i = \min_j data_j\\) where min is over `j` such
-// that `segment_ids[j] == i`.
+// \\(output_i = \min_{j...} data[j...]\\) where min is over tuples `j...` such
+// that `segment_ids[j...] == i`.
 //
 // If the minimum is empty for a given segment ID `i`, it outputs the largest
 // possible value for the specific numeric type,
 // `output[i] = numeric_limits<T>::max()`.
 //
+// If the given segment ID `i` is negative, then the corresponding value is
+// dropped, and will not be included in the result.
+//
 // Arguments:
 //
-//	segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
-// first dimension.
+//	segment_ids: A tensor whose shape is a prefix of `data.shape`.
 //
 //
-// Returns Has same shape as data, except for dimension 0 which
-// has size `num_segments`.
+// Returns Has same shape as data, except for the first `segment_ids.rank`
+// dimensions, which are replaced with a single dimension which has size
+// `num_segments`.
 func UnsortedSegmentMin(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
 	if scope.Err() != nil {
 		return
@@ -3721,11 +3897,12 @@
 
 // Computes the sum along segments of a tensor.
 //
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
 //
 // Computes a tensor such that
-// \\(output[i] = sum_{j...} data[j...]\\) where the sum is over tuples `j...` such
+// \\(output[i] = \sum_{j...} data[j...]\\) where the sum is over tuples `j...` such
 // that `segment_ids[j...] == i`.  Unlike `SegmentSum`, `segment_ids`
 // need not be sorted and need not cover all values in the full
 // range of valid values.
@@ -4302,97 +4479,6 @@
 	return op.Output(0)
 }
 
-// NthElementAttr is an optional argument to NthElement.
-type NthElementAttr func(optionalAttr)
-
-// NthElementReverse sets the optional reverse attribute to value.
-//
-// value: When set to True, find the nth-largest value in the vector and vice
-// versa.
-// If not specified, defaults to false
-func NthElementReverse(value bool) NthElementAttr {
-	return func(m optionalAttr) {
-		m["reverse"] = value
-	}
-}
-
-// Finds values of the `n`-th order statistic for the last dimension.
-//
-// If the input is a vector (rank-1), finds the entries which is the nth-smallest
-// value in the vector and outputs their values as scalar tensor.
-//
-// For matrices (resp. higher rank input), computes the entries which is the
-// nth-smallest value in each row (resp. vector along the last dimension). Thus,
-//
-//     values.shape = input.shape[:-1]
-//
-// Arguments:
-//	input: 1-D or higher with last dimension at least `n+1`.
-//	n: 0-D. Position of sorted vector to select along the last dimension (along
-// each row for matrices). Valid range of n is `[0, input.shape[:-1])`
-//
-// Returns The `n`-th order statistic along each last dimensional slice.
-func NthElement(scope *Scope, input tf.Output, n tf.Output, optional ...NthElementAttr) (values tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{}
-	for _, a := range optional {
-		a(attrs)
-	}
-	opspec := tf.OpSpec{
-		Type: "NthElement",
-		Input: []tf.Input{
-			input, n,
-		},
-		Attrs: attrs,
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
-// Computes the maximum along segments of a tensor.
-//
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
-//
-// This operator is similar to the unsorted segment sum operator found
-// [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
-// Instead of computing the sum over segments, it computes the maximum such that:
-//
-// \\(output_i = \max_j data_j\\) where max is over `j` such
-// that `segment_ids[j] == i`.
-//
-// If the maximum is empty for a given segment ID `i`, it outputs the smallest
-// possible value for the specific numeric type,
-// `output[i] = numeric_limits<T>::lowest()`.
-//
-// <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
-// <img style="width:100%" src="https://www.tensorflow.org/images/UnsortedSegmentMax.png" alt>
-// </div>
-//
-// Arguments:
-//
-//	segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
-// first dimension.
-//
-//
-// Returns Has same shape as data, except for dimension 0 which
-// has size `num_segments`.
-func UnsortedSegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	opspec := tf.OpSpec{
-		Type: "UnsortedSegmentMax",
-		Input: []tf.Input{
-			data, segment_ids, num_segments,
-		},
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
 // Computes exponential of x element-wise.  \\(y = e^x\\).
 func Exp(scope *Scope, x tf.Output) (y tf.Output) {
 	if scope.Err() != nil {
@@ -4500,6 +4586,218 @@
 	return op.Output(0), op.Output(1), op.Output(2)
 }
 
+// PreventGradientAttr is an optional argument to PreventGradient.
+type PreventGradientAttr func(optionalAttr)
+
+// PreventGradientMessage sets the optional message attribute to value.
+//
+// value: Will be printed in the error when anyone tries to differentiate
+// this operation.
+// If not specified, defaults to ""
+func PreventGradientMessage(value string) PreventGradientAttr {
+	return func(m optionalAttr) {
+		m["message"] = value
+	}
+}
+
+// An identity op that triggers an error if a gradient is requested.
+//
+// When executed in a graph, this op outputs its input tensor as-is.
+//
+// When building ops to compute gradients, the TensorFlow gradient system
+// will return an error when trying to lookup the gradient of this op,
+// because no gradient must ever be registered for this function.  This
+// op exists to prevent subtle bugs from silently returning unimplemented
+// gradients in some corner cases.
+//
+// Arguments:
+//	input: any tensor.
+//
+// Returns the same input tensor.
+func PreventGradient(scope *Scope, input tf.Output, optional ...PreventGradientAttr) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "PreventGradient",
+		Input: []tf.Input{
+			input,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
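+// A minimal usage sketch, not generated code, reusing the imports and a scope
+// `s := op.NewScope()` as in the EnsureShape example above:
+//
+// ```go
+// x := op.Const(s, []float32{1, 2, 3})
+// // Any attempt to differentiate through y fails with the given message.
+// y := op.PreventGradient(s, x, op.PreventGradientMessage("this path is not differentiable"))
+// ```
+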
+// Computes asin of x element-wise.
+func Asin(scope *Scope, x tf.Output) (y tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "Asin",
+		Input: []tf.Input{
+			x,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// Computes the maximum along segments of a tensor.
+//
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
+//
+// This operator is similar to the unsorted segment sum operator found
+// [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
+// Instead of computing the sum over segments, it computes the maximum such that:
+//
+// \\(output_i = \max_{j...} data[j...]\\) where max is over tuples `j...` such
+// that `segment_ids[j...] == i`.
+//
+// If the maximum is empty for a given segment ID `i`, it outputs the smallest
+// possible value for the specific numeric type,
+// `output[i] = numeric_limits<T>::lowest()`.
+//
+// If the given segment ID `i` is negative, then the corresponding value is
+// dropped, and will not be included in the result.
+//
+// <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+// <img style="width:100%" src="https://www.tensorflow.org/images/UnsortedSegmentMax.png" alt>
+// </div>
+//
+// Arguments:
+//
+//	segment_ids: A tensor whose shape is a prefix of `data.shape`.
+//
+//
+// Returns Has same shape as data, except for the first `segment_ids.rank`
+// dimensions, which are replaced with a single dimension which has size
+// `num_segments`.
+func UnsortedSegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "UnsortedSegmentMax",
+		Input: []tf.Input{
+			data, segment_ids, num_segments,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
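+// Illustrative usage sketch (not part of the generated bindings), assuming a
+// scope `s` as in the EnsureShape example above:
+//
+// ```go
+// data := op.Const(s, []float32{5, 1, 7, 2})
+// segmentIDs := op.Const(s, []int32{0, 0, 1, 1})
+// numSegments := op.Const(s, int32(2))
+// maxes := op.UnsortedSegmentMax(s, data, segmentIDs, numSegments)
+// // Running the graph would yield [5, 7]: the maximum of each segment.
+// ```
+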
+// NthElementAttr is an optional argument to NthElement.
+type NthElementAttr func(optionalAttr)
+
+// NthElementReverse sets the optional reverse attribute to value.
+//
+// value: When set to True, find the nth-largest value in the vector and vice
+// versa.
+// If not specified, defaults to false
+func NthElementReverse(value bool) NthElementAttr {
+	return func(m optionalAttr) {
+		m["reverse"] = value
+	}
+}
+
+// Finds values of the `n`-th order statistic for the last dimension.
+//
+// If the input is a vector (rank-1), finds the entry which is the nth-smallest
+// value in the vector and outputs its value as a scalar tensor.
+//
+// For matrices (resp. higher rank input), computes the entry which is the
+// nth-smallest value in each row (resp. vector along the last dimension). Thus,
+//
+//     values.shape = input.shape[:-1]
+//
+// Arguments:
+//	input: 1-D or higher with last dimension at least `n+1`.
+//	n: 0-D. Position of sorted vector to select along the last dimension (along
+// each row for matrices). Valid range of n is `[0, input.shape[:-1])`
+//
+// Returns The `n`-th order statistic along each last dimensional slice.
+func NthElement(scope *Scope, input tf.Output, n tf.Output, optional ...NthElementAttr) (values tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "NthElement",
+		Input: []tf.Input{
+			input, n,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// Computes the sum along sparse segments of a tensor.
+//
+// Like `SparseSegmentSum`, but allows missing ids in `segment_ids`. If an id is
+// missing, the `output` tensor at that position will be zeroed.
+//
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
+//
+// For example:
+//
+// ```python
+// c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
+//
+// tf.sparse_segment_sum_with_num_segments(
+//     c, tf.constant([0, 1]), tf.constant([0, 0]), num_segments=3)
+// # => [[0 0 0 0]
+// #     [0 0 0 0]
+// #     [0 0 0 0]]
+//
+// tf.sparse_segment_sum_with_num_segments(c,
+//                                         tf.constant([0, 1]),
+//                                         tf.constant([0, 2],
+//                                         num_segments=4))
+// # => [[ 1  2  3  4]
+// #     [ 0  0  0  0]
+// #     [-1 -2 -3 -4]
+// #     [ 0  0  0  0]]
+// ```
+//
+// Arguments:
+//
+//	indices: A 1-D tensor. Has same rank as `segment_ids`.
+//	segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
+//	num_segments: Should equal the number of distinct segment IDs.
+//
+// Returns Has same shape as data, except for dimension 0 which
+// has size `num_segments`.
+func SparseSegmentSumWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "SparseSegmentSumWithNumSegments",
+		Input: []tf.Input{
+			data, indices, segment_ids, num_segments,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
 // Computes the determinant of one or more square matrices.
 //
 // The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
@@ -5225,6 +5523,47 @@
 	return op.Output(0)
 }
 
+// Transforms `input_dataset` containing `Example` protos as vectors of DT_STRING into a dataset of `Tensor` or `SparseTensor` objects representing the parsed features.
+//
+// Arguments:
+//
+//
+//	dense_defaults: A dict mapping string keys to `Tensor`s.
+// The keys of the dict must match the dense_keys of the feature.
+//	sparse_keys: A list of string keys in the examples features.
+// The results for these keys will be returned as `SparseTensor` objects.
+//	dense_keys: A list of Ndense string Tensors (scalars).
+// The keys expected in the Examples features associated with dense values.
+//	sparse_types: A list of `DTypes` of the same length as `sparse_keys`.
+// Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`),
+// and `tf.string` (`BytesList`) are supported.
+//	dense_shapes: List of tuples with the same length as `dense_keys`.
+// The shape of the data for each dense feature referenced by `dense_keys`.
+// Required for any input tensors identified by `dense_keys`.  Must be
+// either fully defined, or may contain an unknown first dimension.
+// An unknown first dimension means the feature is treated as having
+// a variable number of blocks, and the output shape along this dimension
+// is considered unknown at graph build time.  Padding is applied for
+// minibatch elements smaller than the maximum number of blocks for the
+// given feature along this dimension.
+//	output_types: The type list for the return values.
+//	output_shapes: The list of shapes being produced.
+func ParseExampleDataset(scope *Scope, input_dataset tf.Output, num_parallel_calls tf.Output, dense_defaults []tf.Output, sparse_keys []string, dense_keys []string, sparse_types []tf.DataType, dense_shapes []tf.Shape, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"sparse_keys": sparse_keys, "dense_keys": dense_keys, "sparse_types": sparse_types, "dense_shapes": dense_shapes, "output_types": output_types, "output_shapes": output_shapes}
+	opspec := tf.OpSpec{
+		Type: "ParseExampleDataset",
+		Input: []tf.Input{
+			input_dataset, num_parallel_calls, tf.OutputList(dense_defaults),
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
 // Returns a batched matrix tensor with new batched diagonal values.
 //
 // Given `input` and `diagonal`, this operation returns a tensor with the
@@ -6476,7 +6815,7 @@
 	return offset
 }
 
-// Compute the lower regularized incomplete Gamma function `Q(a, x)`.
+// Compute the lower regularized incomplete Gamma function `P(a, x)`.
 //
 // The lower regularized incomplete Gamma function is defined as:
 //
@@ -7910,6 +8249,214 @@
 	return components
 }
 
+// ParseSequenceExampleAttr is an optional argument to ParseSequenceExample.
+type ParseSequenceExampleAttr func(optionalAttr)
+
+// ParseSequenceExampleNcontextSparse sets the optional Ncontext_sparse attribute to value.
+// If not specified, defaults to 0
+//
+// REQUIRES: value >= 0
+func ParseSequenceExampleNcontextSparse(value int64) ParseSequenceExampleAttr {
+	return func(m optionalAttr) {
+		m["Ncontext_sparse"] = value
+	}
+}
+
+// ParseSequenceExampleNcontextDense sets the optional Ncontext_dense attribute to value.
+// If not specified, defaults to 0
+//
+// REQUIRES: value >= 0
+func ParseSequenceExampleNcontextDense(value int64) ParseSequenceExampleAttr {
+	return func(m optionalAttr) {
+		m["Ncontext_dense"] = value
+	}
+}
+
+// ParseSequenceExampleNfeatureListSparse sets the optional Nfeature_list_sparse attribute to value.
+// If not specified, defaults to 0
+//
+// REQUIRES: value >= 0
+func ParseSequenceExampleNfeatureListSparse(value int64) ParseSequenceExampleAttr {
+	return func(m optionalAttr) {
+		m["Nfeature_list_sparse"] = value
+	}
+}
+
+// ParseSequenceExampleNfeatureListDense sets the optional Nfeature_list_dense attribute to value.
+// If not specified, defaults to 0
+//
+// REQUIRES: value >= 0
+func ParseSequenceExampleNfeatureListDense(value int64) ParseSequenceExampleAttr {
+	return func(m optionalAttr) {
+		m["Nfeature_list_dense"] = value
+	}
+}
+
+// ParseSequenceExampleContextSparseTypes sets the optional context_sparse_types attribute to value.
+//
+// value: A list of Ncontext_sparse types; the data types of data in
+// each context Feature given in context_sparse_keys.
+// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList),
+// DT_INT64 (Int64List), and DT_STRING (BytesList).
+// If not specified, defaults to <>
+//
+// REQUIRES: len(value) >= 0
+func ParseSequenceExampleContextSparseTypes(value []tf.DataType) ParseSequenceExampleAttr {
+	return func(m optionalAttr) {
+		m["context_sparse_types"] = value
+	}
+}
+
+// ParseSequenceExampleFeatureListDenseTypes sets the optional feature_list_dense_types attribute to value.
+// If not specified, defaults to <>
+//
+// REQUIRES: len(value) >= 0
+func ParseSequenceExampleFeatureListDenseTypes(value []tf.DataType) ParseSequenceExampleAttr {
+	return func(m optionalAttr) {
+		m["feature_list_dense_types"] = value
+	}
+}
+
+// ParseSequenceExampleContextDenseShapes sets the optional context_dense_shapes attribute to value.
+//
+// value: A list of Ncontext_dense shapes; the shapes of data in
+// each context Feature given in context_dense_keys.
+// The number of elements in the Feature corresponding to context_dense_key[j]
+// must always equal context_dense_shapes[j].NumEntries().
+// The shape of context_dense_values[j] will match context_dense_shapes[j].
+// If not specified, defaults to <>
+//
+// REQUIRES: len(value) >= 0
+func ParseSequenceExampleContextDenseShapes(value []tf.Shape) ParseSequenceExampleAttr {
+	return func(m optionalAttr) {
+		m["context_dense_shapes"] = value
+	}
+}
+
+// ParseSequenceExampleFeatureListSparseTypes sets the optional feature_list_sparse_types attribute to value.
+//
+// value: A list of Nfeature_list_sparse types; the data types
+// of data in each FeatureList given in feature_list_sparse_keys.
+// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList),
+// DT_INT64 (Int64List), and DT_STRING (BytesList).
+// If not specified, defaults to <>
+//
+// REQUIRES: len(value) >= 0
+func ParseSequenceExampleFeatureListSparseTypes(value []tf.DataType) ParseSequenceExampleAttr {
+	return func(m optionalAttr) {
+		m["feature_list_sparse_types"] = value
+	}
+}
+
+// ParseSequenceExampleFeatureListDenseShapes sets the optional feature_list_dense_shapes attribute to value.
+//
+// value: A list of Nfeature_list_dense shapes; the shapes of
+// data in each FeatureList given in feature_list_dense_keys.
+// The shape of each Feature in the FeatureList corresponding to
+// feature_list_dense_key[j] must always equal
+// feature_list_dense_shapes[j].NumEntries().
+// If not specified, defaults to <>
+//
+// REQUIRES: len(value) >= 0
+func ParseSequenceExampleFeatureListDenseShapes(value []tf.Shape) ParseSequenceExampleAttr {
+	return func(m optionalAttr) {
+		m["feature_list_dense_shapes"] = value
+	}
+}
+
+// Transforms a vector of brain.SequenceExample protos (as strings) into typed tensors.
+//
+// Arguments:
+//	serialized: A vector containing binary serialized SequenceExample protos.
+//	debug_name: A vector containing the names of the serialized protos.
+// May contain, for example, table key (descriptive) name for the
+// corresponding serialized proto.  This is purely useful for debugging
+// purposes, and the presence of values here has no effect on the output.
+// May also be an empty vector if no name is available.
+//	context_dense_defaults: A list of Ncontext_dense Tensors (some may be empty).
+// context_dense_defaults[j] provides default values
+// when the SequenceExample's context map lacks context_dense_key[j].
+// If an empty Tensor is provided for context_dense_defaults[j],
+// then the Feature context_dense_keys[j] is required.
+// The input type is inferred from context_dense_defaults[j], even when it's
+// empty.  If context_dense_defaults[j] is not empty, its shape must match
+// context_dense_shapes[j].
+//	feature_list_dense_missing_assumed_empty: A vector listing the
+// FeatureList keys which may be missing from the SequenceExamples.  If the
+// associated FeatureList is missing, it is treated as empty.  By default,
+// any FeatureList not listed in this vector must exist in the SequenceExamples.
+//	context_sparse_keys: A list of Ncontext_sparse string Tensors (scalars).
+// The keys expected in the Examples' features associated with context_sparse
+// values.
+//	context_dense_keys: A list of Ncontext_dense string Tensors (scalars).
+// The keys expected in the SequenceExamples' context features associated with
+// dense values.
+//	feature_list_sparse_keys: A list of Nfeature_list_sparse string Tensors
+// (scalars).  The keys expected in the FeatureLists associated with sparse
+// values.
+//	feature_list_dense_keys: A list of Nfeature_list_dense string Tensors (scalars).
+// The keys expected in the SequenceExamples' feature_lists associated
+// with lists of dense values.
+func ParseSequenceExample(scope *Scope, serialized tf.Output, debug_name tf.Output, context_dense_defaults []tf.Output, feature_list_dense_missing_assumed_empty []string, context_sparse_keys []string, context_dense_keys []string, feature_list_sparse_keys []string, feature_list_dense_keys []string, optional ...ParseSequenceExampleAttr) (context_sparse_indices []tf.Output, context_sparse_values []tf.Output, context_sparse_shapes []tf.Output, context_dense_values []tf.Output, feature_list_sparse_indices []tf.Output, feature_list_sparse_values []tf.Output, feature_list_sparse_shapes []tf.Output, feature_list_dense_values []tf.Output, feature_list_dense_lengths []tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"feature_list_dense_missing_assumed_empty": feature_list_dense_missing_assumed_empty, "context_sparse_keys": context_sparse_keys, "context_dense_keys": context_dense_keys, "feature_list_sparse_keys": feature_list_sparse_keys, "feature_list_dense_keys": feature_list_dense_keys}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "ParseSequenceExample",
+		Input: []tf.Input{
+			serialized, debug_name, tf.OutputList(context_dense_defaults),
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	if scope.Err() != nil {
+		return
+	}
+	var idx int
+	var err error
+	if context_sparse_indices, idx, err = makeOutputList(op, idx, "context_sparse_indices"); err != nil {
+		scope.UpdateErr("ParseSequenceExample", err)
+		return
+	}
+	if context_sparse_values, idx, err = makeOutputList(op, idx, "context_sparse_values"); err != nil {
+		scope.UpdateErr("ParseSequenceExample", err)
+		return
+	}
+	if context_sparse_shapes, idx, err = makeOutputList(op, idx, "context_sparse_shapes"); err != nil {
+		scope.UpdateErr("ParseSequenceExample", err)
+		return
+	}
+	if context_dense_values, idx, err = makeOutputList(op, idx, "context_dense_values"); err != nil {
+		scope.UpdateErr("ParseSequenceExample", err)
+		return
+	}
+	if feature_list_sparse_indices, idx, err = makeOutputList(op, idx, "feature_list_sparse_indices"); err != nil {
+		scope.UpdateErr("ParseSequenceExample", err)
+		return
+	}
+	if feature_list_sparse_values, idx, err = makeOutputList(op, idx, "feature_list_sparse_values"); err != nil {
+		scope.UpdateErr("ParseSequenceExample", err)
+		return
+	}
+	if feature_list_sparse_shapes, idx, err = makeOutputList(op, idx, "feature_list_sparse_shapes"); err != nil {
+		scope.UpdateErr("ParseSequenceExample", err)
+		return
+	}
+	if feature_list_dense_values, idx, err = makeOutputList(op, idx, "feature_list_dense_values"); err != nil {
+		scope.UpdateErr("ParseSequenceExample", err)
+		return
+	}
+	if feature_list_dense_lengths, idx, err = makeOutputList(op, idx, "feature_list_dense_lengths"); err != nil {
+		scope.UpdateErr("ParseSequenceExample", err)
+		return
+	}
+	return context_sparse_indices, context_sparse_values, context_sparse_shapes, context_dense_values, feature_list_sparse_indices, feature_list_sparse_values, feature_list_sparse_shapes, feature_list_dense_values, feature_list_dense_lengths
+}
+
 // Computes the Gauss error function of `x` element-wise.
 func Erf(scope *Scope, x tf.Output) (y tf.Output) {
 	if scope.Err() != nil {
@@ -8711,6 +9258,66 @@
 	return op.Output(0)
 }
 
+// RandomUniformIntAttr is an optional argument to RandomUniformInt.
+type RandomUniformIntAttr func(optionalAttr)
+
+// RandomUniformIntSeed sets the optional seed attribute to value.
+//
+// value: If either `seed` or `seed2` are set to be non-zero, the random number
+// generator is seeded by the given seed.  Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func RandomUniformIntSeed(value int64) RandomUniformIntAttr {
+	return func(m optionalAttr) {
+		m["seed"] = value
+	}
+}
+
+// RandomUniformIntSeed2 sets the optional seed2 attribute to value.
+//
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func RandomUniformIntSeed2(value int64) RandomUniformIntAttr {
+	return func(m optionalAttr) {
+		m["seed2"] = value
+	}
+}
+
+// Outputs random integers from a uniform distribution.
+//
+// The generated values are uniform integers in the range `[minval, maxval)`.
+// The lower bound `minval` is included in the range, while the upper bound
+// `maxval` is excluded.
+//
+// The random integers are slightly biased unless `maxval - minval` is an exact
+// power of two.  The bias is small for values of `maxval - minval` significantly
+// smaller than the range of the output (either `2^32` or `2^64`).
+//
+// Arguments:
+//	shape: The shape of the output tensor.
+//	minval: 0-D.  Inclusive lower bound on the generated integers.
+//	maxval: 0-D.  Exclusive upper bound on the generated integers.
+//
+// Returns A tensor of the specified shape filled with uniform random integers.
+func RandomUniformInt(scope *Scope, shape tf.Output, minval tf.Output, maxval tf.Output, optional ...RandomUniformIntAttr) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "RandomUniformInt",
+		Input: []tf.Input{
+			shape, minval, maxval,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
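+// A minimal usage sketch, not generated code, assuming a scope `s` as in the
+// EnsureShape example above:
+//
+// ```go
+// shape := op.Const(s, []int64{3})
+// minval := op.Const(s, int64(0))  // inclusive
+// maxval := op.Const(s, int64(10)) // exclusive
+// r := op.RandomUniformInt(s, shape, minval, maxval, op.RandomUniformIntSeed(42))
+// // r holds three integers drawn uniformly from [0, 10).
+// ```
+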
 // ResourceApplyFtrlAttr is an optional argument to ResourceApplyFtrl.
 type ResourceApplyFtrlAttr func(optionalAttr)
 
@@ -9188,6 +9795,49 @@
 	return op.Output(0)
 }
 
+// StaticRegexReplaceAttr is an optional argument to StaticRegexReplace.
+type StaticRegexReplaceAttr func(optionalAttr)
+
+// StaticRegexReplaceReplaceGlobal sets the optional replace_global attribute to value.
+//
+// value: If True, the replacement is global, otherwise the replacement
+// is done only on the first match.
+// If not specified, defaults to true
+func StaticRegexReplaceReplaceGlobal(value bool) StaticRegexReplaceAttr {
+	return func(m optionalAttr) {
+		m["replace_global"] = value
+	}
+}
+
+// Replaces the match of pattern in input with rewrite.
+//
+// It follows the re2 syntax (https://github.com/google/re2/wiki/Syntax)
+//
+// Arguments:
+//	input: The text to be processed.
+//	pattern: The regular expression to match the input.
+//	rewrite: The rewrite to be applied to the matched expression.
+//
+// Returns The text after applying pattern and rewrite.
+func StaticRegexReplace(scope *Scope, input tf.Output, pattern string, rewrite string, optional ...StaticRegexReplaceAttr) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"pattern": pattern, "rewrite": rewrite}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "StaticRegexReplace",
+		Input: []tf.Input{
+			input,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
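+// Illustrative usage sketch (not part of the generated bindings), assuming a
+// scope `s` as in the EnsureShape example above:
+//
+// ```go
+// text := op.Const(s, []string{"one fish, two fish"})
+// replaced := op.StaticRegexReplace(s, text, "fish", "cat",
+// 	op.StaticRegexReplaceReplaceGlobal(true))
+// // Running the graph would yield ["one cat, two cat"].
+// ```
+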
 // Computes gradients for the exponential linear (Elu) operation.
 //
 // Arguments:
@@ -10054,7 +10704,7 @@
 //
 //     [1, 12, 3, 14, 14, 6, 7, 20]
 //
-// See @{tf.scatter_nd} for more details about how to make updates to
+// See `tf.scatter_nd` for more details about how to make updates to
 // slices.
 //
 // Arguments:
@@ -11365,36 +12015,27 @@
 	return op.Output(0)
 }
 
-// The gradient operator for the SparseAdd op.
+// String lengths of `input`.
 //
-// The SparseAdd op calculates A + B, where A, B, and the sum are all represented
-// as `SparseTensor` objects.  This op takes in the upstream gradient w.r.t.
-// non-empty values of the sum, and outputs the gradients w.r.t. the non-empty
-// values of A and B.
+// Computes the length of each string given in the input tensor.
 //
 // Arguments:
-//	backprop_val_grad: 1-D with shape `[nnz(sum)]`.  The gradient with respect to
-// the non-empty values of the sum.
-//	a_indices: 2-D.  The `indices` of the `SparseTensor` A, size `[nnz(A), ndims]`.
-//	b_indices: 2-D.  The `indices` of the `SparseTensor` B, size `[nnz(B), ndims]`.
-//	sum_indices: 2-D.  The `indices` of the sum `SparseTensor`, size
-// `[nnz(sum), ndims]`.
+//	input: The string for which to compute the length.
 //
-// Returns 1-D with shape `[nnz(A)]`. The gradient with respect to the
-// non-empty values of A.1-D with shape `[nnz(B)]`. The gradient with respect to the
-// non-empty values of B.
-func SparseAddGrad(scope *Scope, backprop_val_grad tf.Output, a_indices tf.Output, b_indices tf.Output, sum_indices tf.Output) (a_val_grad tf.Output, b_val_grad tf.Output) {
+// Returns Integer tensor that has the same shape as `input`. The output contains the
+// element-wise string lengths of `input`.
+func StringLength(scope *Scope, input tf.Output) (output tf.Output) {
 	if scope.Err() != nil {
 		return
 	}
 	opspec := tf.OpSpec{
-		Type: "SparseAddGrad",
+		Type: "StringLength",
 		Input: []tf.Input{
-			backprop_val_grad, a_indices, b_indices, sum_indices,
+			input,
 		},
 	}
 	op := scope.AddOperation(opspec)
-	return op.Output(0), op.Output(1)
+	return op.Output(0)
 }
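+
+// A minimal usage sketch, not generated code, assuming a scope `s` as in the
+// EnsureShape example above:
+//
+// ```go
+// strs := op.Const(s, []string{"", "hello", "tensorflow"})
+// lengths := op.StringLength(s, strs)
+// // Running the graph would yield [0, 5, 10] (lengths in bytes).
+// ```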
 
 // Converts each string in the input Tensor to its hash mod by a number of buckets.
@@ -11747,7 +12388,7 @@
 //
 //     [1, 11, 3, 10, 9, 6, 7, 12]
 //
-// See @{tf.scatter_nd} for more details about how to make updates to
+// See `tf.scatter_nd` for more details about how to make updates to
 // slices.
 //
 // Arguments:
@@ -12230,10 +12871,128 @@
 	return op.Output(0)
 }
 
+// ShapeAttr is an optional argument to Shape.
+type ShapeAttr func(optionalAttr)
+
+// ShapeOutType sets the optional out_type attribute to value.
+// If not specified, defaults to DT_INT32
+func ShapeOutType(value tf.DataType) ShapeAttr {
+	return func(m optionalAttr) {
+		m["out_type"] = value
+	}
+}
+
+// Returns the shape of a tensor.
+//
+// This operation returns a 1-D integer tensor representing the shape of `input`.
+//
+// For example:
+//
+// ```
+// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]
+// shape(t) ==> [2, 2, 3]
+// ```
+func Shape(scope *Scope, input tf.Output, optional ...ShapeAttr) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "Shape",
+		Input: []tf.Input{
+			input,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
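+// Illustrative usage sketch (not part of the generated bindings), assuming a
+// scope `s` as in the EnsureShape example above:
+//
+// ```go
+// t := op.Const(s, [][]float32{{1, 2, 3}, {4, 5, 6}})
+// shp := op.Shape(s, t, op.ShapeOutType(tf.Int64))
+// // Running the graph would yield [2, 3].
+// ```
+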
+// Computes the power of one value to another.
+//
+// Given a tensor `x` and a tensor `y`, this operation computes \\(x^y\\) for
+// corresponding elements in `x` and `y`. For example:
+//
+// ```
+// # tensor 'x' is [[2, 2], [3, 3]]
+// # tensor 'y' is [[8, 16], [2, 3]]
+// tf.pow(x, y) ==> [[256, 65536], [9, 27]]
+// ```
+func Pow(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "Pow",
+		Input: []tf.Input{
+			x, y,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// Computes fingerprints of the input strings.
+//
+// Arguments:
+//	input: vector of strings to compute fingerprints on.
+//
+// Returns a (N,2) shaped matrix where N is the number of elements in the input
+// vector. Each row contains the low and high parts of the fingerprint.
+func SdcaFprint(scope *Scope, input tf.Output) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "SdcaFprint",
+		Input: []tf.Input{
+			input,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// The gradient operator for the SparseAdd op.
+//
+// The SparseAdd op calculates A + B, where A, B, and the sum are all represented
+// as `SparseTensor` objects.  This op takes in the upstream gradient w.r.t.
+// non-empty values of the sum, and outputs the gradients w.r.t. the non-empty
+// values of A and B.
+//
+// Arguments:
+//	backprop_val_grad: 1-D with shape `[nnz(sum)]`.  The gradient with respect to
+// the non-empty values of the sum.
+//	a_indices: 2-D.  The `indices` of the `SparseTensor` A, size `[nnz(A), ndims]`.
+//	b_indices: 2-D.  The `indices` of the `SparseTensor` B, size `[nnz(B), ndims]`.
+//	sum_indices: 2-D.  The `indices` of the sum `SparseTensor`, size
+// `[nnz(sum), ndims]`.
+//
+// Returns 1-D with shape `[nnz(A)]`. The gradient with respect to the
+// non-empty values of A.1-D with shape `[nnz(B)]`. The gradient with respect to the
+// non-empty values of B.
+func SparseAddGrad(scope *Scope, backprop_val_grad tf.Output, a_indices tf.Output, b_indices tf.Output, sum_indices tf.Output) (a_val_grad tf.Output, b_val_grad tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "SparseAddGrad",
+		Input: []tf.Input{
+			backprop_val_grad, a_indices, b_indices, sum_indices,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0), op.Output(1)
+}
+
 // Computes the mean along segments of a tensor.
 //
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
 //
 // Computes a tensor such that
 // \\(output_i = \frac{\sum_j data_j}{N}\\) where `mean` is
@@ -12248,7 +13007,7 @@
 //
 // Arguments:
 //
-//	segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
+//	segment_ids: A 1-D tensor whose size is equal to the size of `data`'s
 // first dimension.  Values should be sorted and can be repeated.
 //
 // Returns Has same shape as data, except for dimension 0 which
@@ -12367,7 +13126,7 @@
 //
 // Arguments:
 //	input: A string tensor of the text to be processed.
-//	pattern: A 1-D string tensor of the regular expression to match the input.
+//	pattern: A scalar string tensor containing the regular expression to match the input.
 //
 // Returns A bool tensor with the same shape as `input`.
 func RegexFullMatch(scope *Scope, input tf.Output, pattern tf.Output) (output tf.Output) {
@@ -12421,6 +13180,79 @@
 	return op.Output(0)
 }
 
+// RandomPoissonV2Attr is an optional argument to RandomPoissonV2.
+type RandomPoissonV2Attr func(optionalAttr)
+
+// RandomPoissonV2Seed sets the optional seed attribute to value.
+//
+// value: If either `seed` or `seed2` are set to be non-zero, the random number
+// generator is seeded by the given seed.  Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func RandomPoissonV2Seed(value int64) RandomPoissonV2Attr {
+	return func(m optionalAttr) {
+		m["seed"] = value
+	}
+}
+
+// RandomPoissonV2Seed2 sets the optional seed2 attribute to value.
+//
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func RandomPoissonV2Seed2(value int64) RandomPoissonV2Attr {
+	return func(m optionalAttr) {
+		m["seed2"] = value
+	}
+}
+
+// RandomPoissonV2Dtype sets the optional dtype attribute to value.
+// If not specified, defaults to DT_INT64
+func RandomPoissonV2Dtype(value tf.DataType) RandomPoissonV2Attr {
+	return func(m optionalAttr) {
+		m["dtype"] = value
+	}
+}
+
+// Outputs random values from the Poisson distribution(s) described by rate.
+//
+// This op uses two algorithms, depending on rate. If rate >= 10, then
+// the algorithm by Hormann is used to acquire samples via
+// transformation-rejection.
+// See http://www.sciencedirect.com/science/article/pii/0167668793909974.
+//
+// Otherwise, Knuth's algorithm is used to acquire samples via multiplying uniform
+// random variables.
+// See Donald E. Knuth (1969). Seminumerical Algorithms. The Art of Computer
+// Programming, Volume 2. Addison Wesley
+//
+// Arguments:
+//	shape: 1-D integer tensor. Shape of independent samples to draw from each
+// distribution described by the shape parameters given in rate.
+//	rate: A tensor in which each scalar is a "rate" parameter describing the
+// associated poisson distribution.
+//
+// Returns A tensor with shape `shape + shape(rate)`. Each slice
+// `[:, ..., :, i0, i1, ...iN]` contains the samples drawn for
+// `rate[i0, i1, ...iN]`.
+func RandomPoissonV2(scope *Scope, shape tf.Output, rate tf.Output, optional ...RandomPoissonV2Attr) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "RandomPoissonV2",
+		Input: []tf.Input{
+			shape, rate,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
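+// A minimal usage sketch, not generated code, assuming a scope `s` as in the
+// EnsureShape example above:
+//
+// ```go
+// shape := op.Const(s, []int64{4})         // four independent draws per rate
+// rate := op.Const(s, []float32{1.5, 8.0}) // two Poisson rates
+// samples := op.RandomPoissonV2(s, shape, rate,
+// 	op.RandomPoissonV2Dtype(tf.Float), op.RandomPoissonV2Seed(7))
+// // samples has shape [4, 2], i.e. shape + shape(rate).
+// ```
+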
 // DecodeAndCropJpegAttr is an optional argument to DecodeAndCropJpeg.
 type DecodeAndCropJpegAttr func(optionalAttr)
 
@@ -14443,6 +15275,25 @@
 	return scope.AddOperation(opspec)
 }
 
+// Returns 0 if the denominator is zero.
+//
+//
+// *NOTE*: `DivNoNan` supports broadcasting. More about broadcasting
+// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+func DivNoNan(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "DivNoNan",
+		Input: []tf.Input{
+			x, y,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
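+// Illustrative usage sketch (not part of the generated bindings), assuming a
+// scope `s` as in the EnsureShape example above:
+//
+// ```go
+// x := op.Const(s, []float32{3, 0, -4})
+// y := op.Const(s, []float32{2, 0, 0})
+// z := op.DivNoNan(s, x, y)
+// // Running the graph would yield [1.5, 0, 0]: zero wherever y is zero.
+// ```
+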
 // Computes the gradient for the sqrt of `x` wrt its input.
 //
 // Specifically, `grad = dy * 0.5 / y`, where `y = sqrt(x)`, and `dy`
@@ -15350,6 +16201,36 @@
 	return op.Output(0)
 }
 
+// Check if the input matches the regex pattern.
+//
+// The input is a string tensor of any shape. The pattern is the
+// regular expression to be matched with every element of the input tensor.
+// The boolean values (True or False) of the output tensor indicate
+// if the input matches the regex pattern provided.
+//
+// The pattern follows the re2 syntax (https://github.com/google/re2/wiki/Syntax)
+//
+// Arguments:
+//	input: A string tensor of the text to be processed.
+//	pattern: The regular expression to match the input.
+//
+// Returns A bool tensor with the same shape as `input`.
+func StaticRegexFullMatch(scope *Scope, input tf.Output, pattern string) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"pattern": pattern}
+	opspec := tf.OpSpec{
+		Type: "StaticRegexFullMatch",
+		Input: []tf.Input{
+			input,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
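+// A minimal usage sketch, not generated code, assuming a scope `s` as in the
+// EnsureShape example above:
+//
+// ```go
+// lines := op.Const(s, []string{"abc123", "abc", "123"})
+// matched := op.StaticRegexFullMatch(s, lines, `[a-z]+[0-9]+`)
+// // Running the graph would yield [true, false, false]: only full matches count.
+// ```
+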
 // ResourceSparseApplyProximalGradientDescentAttr is an optional argument to ResourceSparseApplyProximalGradientDescent.
 type ResourceSparseApplyProximalGradientDescentAttr func(optionalAttr)
 
@@ -15947,6 +16828,23 @@
 	return scope.AddOperation(opspec)
 }
 
+// Creates a dataset containing elements of first component of `input_dataset` having true in the last component.
+func FilterByLastComponentDataset(scope *Scope, input_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
+	opspec := tf.OpSpec{
+		Type: "FilterByLastComponentDataset",
+		Input: []tf.Input{
+			input_dataset,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
 // CudnnRNNCanonicalToParamsAttr is an optional argument to CudnnRNNCanonicalToParams.
 type CudnnRNNCanonicalToParamsAttr func(optionalAttr)
 
@@ -16806,7 +17704,8 @@
 //	records: Each string is a record/row in the csv and all records should have
 // the same format.
 //	record_defaults: One tensor per column of the input record, with either a
-// scalar default value for that column or empty if the column is required.
+// scalar default value for that column or an empty vector if the column is
+// required.
 //
 // Returns Each tensor will have the same shape as records.
 func DecodeCSV(scope *Scope, records tf.Output, record_defaults []tf.Output, optional ...DecodeCSVAttr) (output []tf.Output) {
@@ -17573,8 +18472,9 @@
 
 // Computes the sum along segments of a tensor.
 //
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
 //
 // Computes a tensor such that
 // \\(output_i = \sum_j data_j\\) where sum is over `j` such
@@ -17588,7 +18488,7 @@
 //
 // Arguments:
 //
-//	segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
+//	segment_ids: A 1-D tensor whose size is equal to the size of `data`'s
 // first dimension.  Values should be sorted and can be repeated.
 //
 // Returns Has same shape as data, except for dimension 0 which
@@ -19505,8 +20405,9 @@
 
 // Computes the minimum along segments of a tensor.
 //
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
 //
 // Computes a tensor such that
 // \\(output_i = \min_j(data_j)\\) where `min` is over `j` such
@@ -19520,7 +20421,7 @@
 //
 // Arguments:
 //
-//	segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
+//	segment_ids: A 1-D tensor whose size is equal to the size of `data`'s
 // first dimension.  Values should be sorted and can be repeated.
 //
 // Returns Has same shape as data, except for dimension 0 which
@@ -19634,164 +20535,6 @@
 	return out_example_state_data, out_delta_sparse_weights, out_delta_dense_weights
 }
 
-// ShapeAttr is an optional argument to Shape.
-type ShapeAttr func(optionalAttr)
-
-// ShapeOutType sets the optional out_type attribute to value.
-// If not specified, defaults to DT_INT32
-func ShapeOutType(value tf.DataType) ShapeAttr {
-	return func(m optionalAttr) {
-		m["out_type"] = value
-	}
-}
-
-// Returns the shape of a tensor.
-//
-// This operation returns a 1-D integer tensor representing the shape of `input`.
-//
-// For example:
-//
-// ```
-// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]
-// shape(t) ==> [2, 2, 3]
-// ```
-func Shape(scope *Scope, input tf.Output, optional ...ShapeAttr) (output tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{}
-	for _, a := range optional {
-		a(attrs)
-	}
-	opspec := tf.OpSpec{
-		Type: "Shape",
-		Input: []tf.Input{
-			input,
-		},
-		Attrs: attrs,
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
-// Computes the power of one value to another.
-//
-// Given a tensor `x` and a tensor `y`, this operation computes \\(x^y\\) for
-// corresponding elements in `x` and `y`. For example:
-//
-// ```
-// # tensor 'x' is [[2, 2]], [3, 3]]
-// # tensor 'y' is [[8, 16], [2, 3]]
-// tf.pow(x, y) ==> [[256, 65536], [9, 27]]
-// ```
-func Pow(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	opspec := tf.OpSpec{
-		Type: "Pow",
-		Input: []tf.Input{
-			x, y,
-		},
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
-// Computes fingerprints of the input strings.
-//
-// Arguments:
-//	input: vector of strings to compute fingerprints on.
-//
-// Returns a (N,2) shaped matrix where N is the number of elements in the input
-// vector. Each row contains the low and high parts of the fingerprint.
-func SdcaFprint(scope *Scope, input tf.Output) (output tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	opspec := tf.OpSpec{
-		Type: "SdcaFprint",
-		Input: []tf.Input{
-			input,
-		},
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
-// RandomPoissonV2Attr is an optional argument to RandomPoissonV2.
-type RandomPoissonV2Attr func(optionalAttr)
-
-// RandomPoissonV2Seed sets the optional seed attribute to value.
-//
-// value: If either `seed` or `seed2` are set to be non-zero, the random number
-// generator is seeded by the given seed.  Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func RandomPoissonV2Seed(value int64) RandomPoissonV2Attr {
-	return func(m optionalAttr) {
-		m["seed"] = value
-	}
-}
-
-// RandomPoissonV2Seed2 sets the optional seed2 attribute to value.
-//
-// value: A second seed to avoid seed collision.
-// If not specified, defaults to 0
-func RandomPoissonV2Seed2(value int64) RandomPoissonV2Attr {
-	return func(m optionalAttr) {
-		m["seed2"] = value
-	}
-}
-
-// RandomPoissonV2Dtype sets the optional dtype attribute to value.
-// If not specified, defaults to DT_INT64
-func RandomPoissonV2Dtype(value tf.DataType) RandomPoissonV2Attr {
-	return func(m optionalAttr) {
-		m["dtype"] = value
-	}
-}
-
-// Outputs random values from the Poisson distribution(s) described by rate.
-//
-// This op uses two algorithms, depending on rate. If rate >= 10, then
-// the algorithm by Hormann is used to acquire samples via
-// transformation-rejection.
-// See http://www.sciencedirect.com/science/article/pii/0167668793909974.
-//
-// Otherwise, Knuth's algorithm is used to acquire samples via multiplying uniform
-// random variables.
-// See Donald E. Knuth (1969). Seminumerical Algorithms. The Art of Computer
-// Programming, Volume 2. Addison Wesley
-//
-// Arguments:
-//	shape: 1-D integer tensor. Shape of independent samples to draw from each
-// distribution described by the shape parameters given in rate.
-//	rate: A tensor in which each scalar is a "rate" parameter describing the
-// associated poisson distribution.
-//
-// Returns A tensor with shape `shape + shape(rate)`. Each slice
-// `[:, ..., :, i0, i1, ...iN]` contains the samples drawn for
-// `rate[i0, i1, ...iN]`.
-func RandomPoissonV2(scope *Scope, shape tf.Output, rate tf.Output, optional ...RandomPoissonV2Attr) (output tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{}
-	for _, a := range optional {
-		a(attrs)
-	}
-	opspec := tf.OpSpec{
-		Type: "RandomPoissonV2",
-		Input: []tf.Input{
-			shape, rate,
-		},
-		Attrs: attrs,
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
 // MatrixTriangularSolveAttr is an optional argument to MatrixTriangularSolve.
 type MatrixTriangularSolveAttr func(optionalAttr)
 
@@ -20266,27 +21009,31 @@
 
 // Computes the product along segments of a tensor.
 //
-// Read @{$math_ops#segmentation$the section on segmentation} for an explanation of
-// segments.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#segmentation)
+// for an explanation of segments.
 //
 // This operator is similar to the unsorted segment sum operator found
 // [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
 // Instead of computing the sum over segments, it computes the product of all
 // entries belonging to a segment such that:
 //
-// \\(output_i = \prod_j data_j\\) where the product is over `j` such
-// that `segment_ids[j] == i`.
+// \\(output_i = \prod_{j...} data[j...]\\) where the product is over tuples
+// `j...` such that `segment_ids[j...] == i`.
 //
 // If there is no entry for a given segment ID `i`, it outputs 1.
 //
+// If the given segment ID `i` is negative, then the corresponding value is
+// dropped, and will not be included in the result.
+//
 // Arguments:
 //
-//	segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
-// first dimension.
+//	segment_ids: A tensor whose shape is a prefix of `data.shape`.
 //
 //
-// Returns Has same shape as data, except for dimension 0 which
-// has size `num_segments`.
+// Returns Has same shape as data, except for the first `segment_ids.rank`
+// dimensions, which are replaced with a single dimension which has size
+// `num_segments`.
 func UnsortedSegmentProd(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
 	if scope.Err() != nil {
 		return
@@ -20301,70 +21048,11 @@
 	return op.Output(0)
 }
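As a rough Python illustration of the segment reductions documented above (assuming the TF 1.x `tf.segment_sum` and `tf.unsorted_segment_prod` wrappers for these ops; note how a negative id is simply dropped by the unsorted variant):

```python
import tensorflow as tf

data = tf.constant([1, 2, 3, 4], dtype=tf.int32)

# Sorted segment ids: one id per row of `data`, non-decreasing.
sorted_ids = tf.constant([0, 0, 1, 1])
seg_sum = tf.segment_sum(data, sorted_ids)  # expect [3, 7]

# Unsorted variant: ids may repeat in any order; a negative id drops the value.
unsorted_ids = tf.constant([1, 0, -1, 0])
seg_prod = tf.unsorted_segment_prod(data, unsorted_ids, num_segments=2)  # expect [8, 1]

with tf.Session() as sess:
    print(sess.run([seg_sum, seg_prod]))
```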
 
-// RandomUniformIntAttr is an optional argument to RandomUniformInt.
-type RandomUniformIntAttr func(optionalAttr)
-
-// RandomUniformIntSeed sets the optional seed attribute to value.
-//
-// value: If either `seed` or `seed2` are set to be non-zero, the random number
-// generator is seeded by the given seed.  Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func RandomUniformIntSeed(value int64) RandomUniformIntAttr {
-	return func(m optionalAttr) {
-		m["seed"] = value
-	}
-}
-
-// RandomUniformIntSeed2 sets the optional seed2 attribute to value.
-//
-// value: A second seed to avoid seed collision.
-// If not specified, defaults to 0
-func RandomUniformIntSeed2(value int64) RandomUniformIntAttr {
-	return func(m optionalAttr) {
-		m["seed2"] = value
-	}
-}
-
-// Outputs random integers from a uniform distribution.
-//
-// The generated values are uniform integers in the range `[minval, maxval)`.
-// The lower bound `minval` is included in the range, while the upper bound
-// `maxval` is excluded.
-//
-// The random integers are slightly biased unless `maxval - minval` is an exact
-// power of two.  The bias is small for values of `maxval - minval` significantly
-// smaller than the range of the output (either `2^32` or `2^64`).
-//
-// Arguments:
-//	shape: The shape of the output tensor.
-//	minval: 0-D.  Inclusive lower bound on the generated integers.
-//	maxval: 0-D.  Exclusive upper bound on the generated integers.
-//
-// Returns A tensor of the specified shape filled with uniform random integers.
-func RandomUniformInt(scope *Scope, shape tf.Output, minval tf.Output, maxval tf.Output, optional ...RandomUniformIntAttr) (output tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{}
-	for _, a := range optional {
-		a(attrs)
-	}
-	opspec := tf.OpSpec{
-		Type: "RandomUniformInt",
-		Input: []tf.Input{
-			shape, minval, maxval,
-		},
-		Attrs: attrs,
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
 // Computes the mean along sparse segments of a tensor.
 //
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
 //
 // Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first
 // dimension, selecting a subset of dimension 0, specified by `indices`.
@@ -20433,8 +21121,9 @@
 // Like `SparseSegmentMean`, but allows missing ids in `segment_ids`. If an id is
 // missing, the `output` tensor at that position will be zeroed.
 //
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
 //
 // Arguments:
 //
@@ -20579,8 +21268,9 @@
 //
 // N is the size of the segment being reduced.
 //
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
 //
 // Arguments:
 //
@@ -20638,8 +21328,9 @@
 // Like `SparseSegmentSqrtN`, but allows missing ids in `segment_ids`. If an id is
 // missing, the `output` tensor at that position will be zeroed.
 //
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
 //
 // Arguments:
 //
@@ -21000,8 +21691,9 @@
 
 // Computes the maximum along segments of a tensor.
 //
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
 //
 // Computes a tensor such that
 // \\(output_i = \max_j(data_j)\\) where `max` is over `j` such
@@ -21015,7 +21707,7 @@
 //
 // Arguments:
 //
-//	segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
+//	segment_ids: A 1-D tensor whose size is equal to the size of `data`'s
 // first dimension.  Values should be sorted and can be repeated.
 //
 // Returns Has same shape as data, except for dimension 0 which
@@ -23431,29 +24123,57 @@
 	return op.Output(0)
 }
 
-// Computes the matrix exponential of one or more square matrices:
+// Creates a Tensor by indexing into the TensorList.
+//
+// Each row in the produced Tensor corresponds to the element in the TensorList
+// specified by the given index (see `tf.gather`).
+//
+// input_handle: The input tensor list.
+// indices: The indices used to index into the list.
+// values: The tensor.
+func TensorListGather(scope *Scope, input_handle tf.Output, indices tf.Output, element_dtype tf.DataType) (values tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"element_dtype": element_dtype}
+	opspec := tf.OpSpec{
+		Type: "TensorListGather",
+		Input: []tf.Input{
+			input_handle, indices,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// Creates a TensorList by indexing into a Tensor.
+//
+// Each member of the TensorList corresponds to one row of the input tensor,
+// specified by the given index (see `tf.gather`).
+//
+// tensor: The input tensor.
+// indices: The indices used to index into the list.
+// element_shape: The shape of the elements in the list (can be less specified than
+//   the shape of the tensor).
+// output_handle: The TensorList.
+func TensorListScatter(scope *Scope, tensor tf.Output, indices tf.Output, element_shape tf.Output) (output_handle tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "TensorListScatter",
+		Input: []tf.Input{
+			tensor, indices, element_shape,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// Deprecated, use python implementation tf.linalg.matrix_exponential.
 //
 // DEPRECATED at GraphDef version 27: Use Python implementation tf.linalg.matrix_exponential instead.
-//
-// \\(exp(A) = \sum_{n=0}^\infty A^n/n!\\)
-//
-// The exponential is computed using a combination of the scaling and squaring
-// method and the Pade approximation. Details can be founds in:
-// Nicholas J. Higham, "The scaling and squaring method for the matrix exponential
-// revisited," SIAM J. Matrix Anal. Applic., 26:1179-1193, 2005.
-//
-// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
-// form square matrices. The output is a tensor of the same shape as the input
-// containing the exponential for all input submatrices `[..., :, :]`.
-//
-// Arguments:
-//	input: Shape is `[..., M, M]`.
-//
-// Returns Shape is `[..., M, M]`.
-//
-// @compatibility(scipy)
-// Equivalent to scipy.linalg.expm
-// @end_compatibility
 func MatrixExponential(scope *Scope, input tf.Output) (output tf.Output) {
 	if scope.Err() != nil {
 		return
@@ -23959,8 +24679,9 @@
 
 // Computes the product along segments of a tensor.
 //
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
 //
 // Computes a tensor such that
 // \\(output_i = \prod_j data_j\\) where the product is over `j` such
@@ -23974,7 +24695,7 @@
 //
 // Arguments:
 //
-//	segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
+//	segment_ids: A 1-D tensor whose size is equal to the size of `data`'s
 // first dimension.  Values should be sorted and can be repeated.
 //
 // Returns Has same shape as data, except for dimension 0 which
@@ -24999,7 +25720,7 @@
 
 // Update '*var' according to the Adam algorithm.
 //
-// $$lr_t := \text{learning_rate} * \sqrt{(1 - beta_2^t) / (1 - beta_1^t)}$$
+// $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
 // $$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
 // $$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
 // $$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$
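A quick arithmetic check of the corrected `lr_t` line above, in plain Python with illustrative hyperparameter values:

```python
import math

learning_rate, beta1, beta2, t = 0.001, 0.9, 0.999, 10
lr_t = learning_rate * math.sqrt(1.0 - beta2 ** t) / (1.0 - beta1 ** t)
print(lr_t)  # ~0.000153: the bias correction shrinks the effective step early in training
```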
@@ -27016,8 +27737,10 @@
 // If `len` defines a substring that would extend beyond the length of the input
 // string, then as many characters as possible are used.
 //
-// If `pos` is negative or specifies a character index larger than any of the input
-// strings, then an `InvalidArgumentError` is thrown.
+// A negative `pos` indicates distance within the string backwards from the end.
+//
+// If `pos` specifies an index which is out of range for any of the input strings,
+// then an `InvalidArgumentError` is thrown.
 //
 // `pos` and `len` must have the same shape, otherwise a `ValueError` is thrown on
 // Op creation.
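A minimal sketch of the negative-`pos` behaviour described above, assuming the TF 1.x `tf.substr` wrapper for this op:

```python
import tensorflow as tf

s = tf.constant(["TensorFlow"])
tail = tf.substr(s, pos=-4, len=4)   # negative pos counts back from the end of the string
head = tf.substr(s, pos=0, len=6)

with tf.Session() as sess:
    print(sess.run([tail, head]))  # [b'Flow'] and [b'Tensor']
```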
@@ -27422,35 +28145,6 @@
 	return scope.AddOperation(opspec)
 }
 
-// Makes the summary of accumulated stats for the batch.
-//
-// The summary stats contains gradients and hessians accumulated into the corresponding node and bucket for each example.
-//
-// Arguments:
-//	node_ids: int32 Rank 1 Tensor containing node ids, which each example falls into for the requested layer.
-//	gradients: float32; Rank 2 Tensor (shape=[#examples, 1]) for gradients.
-//	hessians: float32; Rank 2 Tensor (shape=[#examples, 1]) for hessians.
-//	bucketized_features_list: int32 list of Rank 1 Tensors, each containing the bucketized feature (for each feature column).
-//	max_splits: int; the maximum number of splits possible in the whole tree.
-//	num_buckets: int; equals to the maximum possible value of bucketized feature.
-//
-// Returns output Rank 4 Tensor (shape=[#features, #splits, #buckets, 2]) containing accumulated stats put into the corresponding node and bucket. The first index of 4th dimension refers to gradients, and the second to hessians.
-func BoostedTreesMakeStatsSummary(scope *Scope, node_ids tf.Output, gradients tf.Output, hessians tf.Output, bucketized_features_list []tf.Output, max_splits int64, num_buckets int64) (stats_summary tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{"max_splits": max_splits, "num_buckets": num_buckets}
-	opspec := tf.OpSpec{
-		Type: "BoostedTreesMakeStatsSummary",
-		Input: []tf.Input{
-			node_ids, gradients, hessians, tf.OutputList(bucketized_features_list),
-		},
-		Attrs: attrs,
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
 // Adjust the contrast of one or more images.
 //
 // `images` is a tensor of at least 3 dimensions.  The last 3 dimensions are
@@ -27643,6 +28337,8 @@
 // On GPU, if an out of bound index is found, a 0 is stored in the
 // corresponding output value.
 //
+// See also `tf.batch_gather` and `tf.gather_nd`.
+//
 // Arguments:
 //	params: The tensor from which to gather values. Must be at least rank
 // `axis + 1`.
@@ -28153,6 +28849,30 @@
 	return op.Output(0)
 }
 
+// Identity transformation that models performance.
+//
+// Arguments:
+//	input_dataset: A variant tensor representing the input dataset.
+//
+//
+func ModelDataset(scope *Scope, input_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
+	opspec := tf.OpSpec{
+		Type: "ModelDataset",
+		Input: []tf.Input{
+			input_dataset,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
 // Performs a padding as a preprocess during a convolution.
 //
 // Similar to FusedResizeAndPadConv2d, this op allows for an optimized
@@ -28842,10 +29562,16 @@
 //
 // Arguments:
 //
-//	window_size: A scalar representing the number of elements to accumulate in a window.
+//	size: A scalar representing the number of elements to accumulate in a window.
+//	shift: A scalar representing the steps moving the sliding window forward in one
+// iteration. It must be positive.
+//	stride: A scalar representing the stride of the input elements of the sliding window.
+// It must be positive.
+//	drop_remainder: A scalar representing whether a window should be dropped in case its size is
+// smaller than desired.
 //
 //
-func WindowDataset(scope *Scope, input_dataset tf.Output, window_size tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
+func WindowDataset(scope *Scope, input_dataset tf.Output, size tf.Output, shift tf.Output, stride tf.Output, drop_remainder tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
 	if scope.Err() != nil {
 		return
 	}
@@ -28853,7 +29579,7 @@
 	opspec := tf.OpSpec{
 		Type: "WindowDataset",
 		Input: []tf.Input{
-			input_dataset, window_size,
+			input_dataset, size, shift, stride, drop_remainder,
 		},
 		Attrs: attrs,
 	}
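For orientation only: the new `size`, `shift`, `stride`, and `drop_remainder` arguments mirror the public `tf.data.Dataset.window` method that fronts this op in later releases. A sketch under that assumption, not necessarily the exact API available at this version:

```python
import tensorflow as tf

ds = tf.data.Dataset.range(7)
# Windows of 3 elements, advancing by 2 elements per window, taking every
# element within a window (stride 1), and dropping the final short window.
windows = ds.window(size=3, shift=2, stride=1, drop_remainder=True)
# Expected window contents: [0, 1, 2], [2, 3, 4], [4, 5, 6]
```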
@@ -30063,7 +30789,7 @@
 //
 //     [1, 13, 3, 14, 14, 6, 7, 20]
 //
-// See @{tf.scatter_nd} for more details about how to make updates to slices.
+// See `tf.scatter_nd` for more details about how to make updates to slices.
 //
 // Arguments:
 //	input: A Tensor.
@@ -30680,6 +31406,41 @@
 	return op.Output(0)
 }
 
+// Generate the bucket boundaries for each feature based on accumulated summaries.
+//
+// An op that returns a list of float tensors for a quantile stream resource. Each
+// tensor is Rank 1 containing bucket boundaries for a single feature.
+//
+// Arguments:
+//	quantile_stream_resource_handle: resource handle referring to a QuantileStreamResource.
+//	num_features: inferred int; number of features to get bucket boundaries for.
+//
+// Returns float; List of Rank 1 Tensors each containing the bucket boundaries for a feature.
+func BoostedTreesQuantileStreamResourceGetBucketBoundaries(scope *Scope, quantile_stream_resource_handle tf.Output, num_features int64) (bucket_boundaries []tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"num_features": num_features}
+	opspec := tf.OpSpec{
+		Type: "BoostedTreesQuantileStreamResourceGetBucketBoundaries",
+		Input: []tf.Input{
+			quantile_stream_resource_handle,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	if scope.Err() != nil {
+		return
+	}
+	var idx int
+	var err error
+	if bucket_boundaries, idx, err = makeOutputList(op, idx, "bucket_boundaries"); err != nil {
+		scope.UpdateErr("BoostedTreesQuantileStreamResourceGetBucketBoundaries", err)
+		return
+	}
+	return bucket_boundaries
+}
+
 // OrderedMapUnstageAttr is an optional argument to OrderedMapUnstage.
 type OrderedMapUnstageAttr func(optionalAttr)
 
@@ -30751,6 +31512,43 @@
 	return values
 }
 
+// BoostedTreesQuantileStreamResourceHandleOpAttr is an optional argument to BoostedTreesQuantileStreamResourceHandleOp.
+type BoostedTreesQuantileStreamResourceHandleOpAttr func(optionalAttr)
+
+// BoostedTreesQuantileStreamResourceHandleOpContainer sets the optional container attribute to value.
+// If not specified, defaults to ""
+func BoostedTreesQuantileStreamResourceHandleOpContainer(value string) BoostedTreesQuantileStreamResourceHandleOpAttr {
+	return func(m optionalAttr) {
+		m["container"] = value
+	}
+}
+
+// BoostedTreesQuantileStreamResourceHandleOpSharedName sets the optional shared_name attribute to value.
+// If not specified, defaults to ""
+func BoostedTreesQuantileStreamResourceHandleOpSharedName(value string) BoostedTreesQuantileStreamResourceHandleOpAttr {
+	return func(m optionalAttr) {
+		m["shared_name"] = value
+	}
+}
+
+// Creates a handle to a BoostedTreesQuantileStreamResource.
+func BoostedTreesQuantileStreamResourceHandleOp(scope *Scope, optional ...BoostedTreesQuantileStreamResourceHandleOpAttr) (resource tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "BoostedTreesQuantileStreamResourceHandleOp",
+
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
 // OrderedMapSizeAttr is an optional argument to OrderedMapSize.
 type OrderedMapSizeAttr func(optionalAttr)
 
diff --git a/tensorflow/java/README.md b/tensorflow/java/README.md
index c7382ff..7ef862a 100644
--- a/tensorflow/java/README.md
+++ b/tensorflow/java/README.md
@@ -10,7 +10,7 @@
 
 ## Quickstart
 
--   Refer to [Installing TensorFlow for Java](https://www.tensorflow.org/install/install_java)
+-   Refer to [Installing TensorFlow for Java](https://www.tensorflow.org/install/lang_java)
 -   [Javadoc](https://www.tensorflow.org/api_docs/java/reference/org/tensorflow/package-summary)
 -   [![Maven Central](https://maven-badges.herokuapp.com/maven-central/org.tensorflow/tensorflow/badge.svg)](https://maven-badges.herokuapp.com/maven-central/org.tensorflow/tensorflow)
 
@@ -22,8 +22,7 @@
 1.  Install [bazel](https://www.bazel.build/versions/master/docs/install.html)
 
 2.  Setup the environment to build TensorFlow from source code
-    ([Linux](https://www.tensorflow.org/install/install_sources#PrepareLinux)
-    or [macOS](https://www.tensorflow.org/install/install_sources#PrepareMac)).
+    ([Linux or macOS](https://www.tensorflow.org/install/source)).
     If you'd like to skip reading those details and do not care about GPU
     support, try the following:
 
@@ -35,7 +34,7 @@
     brew install swig
     ```
 
-3.  [Configure](https://www.tensorflow.org/install/install_sources#configure_the_installation)
+3.  [Configure](https://www.tensorflow.org/install/source)
     (e.g., enable GPU support) and build:
 
     ```sh
diff --git a/tensorflow/java/maven/libtensorflow/pom.xml b/tensorflow/java/maven/libtensorflow/pom.xml
index f9093ce..6c82301 100644
--- a/tensorflow/java/maven/libtensorflow/pom.xml
+++ b/tensorflow/java/maven/libtensorflow/pom.xml
@@ -6,7 +6,7 @@
   <parent>
     <groupId>org.tensorflow</groupId>
     <artifactId>parentpom</artifactId>
-    <version>1.10.0</version>
+    <version>1.11.0-rc1</version>
     <relativePath>../</relativePath>
   </parent>
   <artifactId>libtensorflow</artifactId>
diff --git a/tensorflow/java/maven/libtensorflow_jni/pom.xml b/tensorflow/java/maven/libtensorflow_jni/pom.xml
index 1208956..f763479 100644
--- a/tensorflow/java/maven/libtensorflow_jni/pom.xml
+++ b/tensorflow/java/maven/libtensorflow_jni/pom.xml
@@ -6,7 +6,7 @@
   <parent>
     <groupId>org.tensorflow</groupId>
     <artifactId>parentpom</artifactId>
-    <version>1.10.0</version>
+    <version>1.11.0-rc1</version>
     <relativePath>../</relativePath>
   </parent>
   <artifactId>libtensorflow_jni</artifactId>
diff --git a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
index 755449c..7fcc6ff 100644
--- a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
+++ b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
@@ -6,7 +6,7 @@
   <parent>
     <groupId>org.tensorflow</groupId>
     <artifactId>parentpom</artifactId>
-    <version>1.10.0</version>
+    <version>1.11.0-rc1</version>
     <relativePath>../</relativePath>
   </parent>
   <artifactId>libtensorflow_jni_gpu</artifactId>
diff --git a/tensorflow/java/maven/pom.xml b/tensorflow/java/maven/pom.xml
index e1bf2c7..689902e 100644
--- a/tensorflow/java/maven/pom.xml
+++ b/tensorflow/java/maven/pom.xml
@@ -6,7 +6,7 @@
   <modelVersion>4.0.0</modelVersion>
   <groupId>org.tensorflow</groupId>
   <artifactId>parentpom</artifactId>
-  <version>1.10.0</version>
+  <version>1.11.0-rc1</version>
   <packaging>pom</packaging>
 
   <url>https://www.tensorflow.org</url>
diff --git a/tensorflow/java/maven/proto/pom.xml b/tensorflow/java/maven/proto/pom.xml
index b89f042..ea1462a 100644
--- a/tensorflow/java/maven/proto/pom.xml
+++ b/tensorflow/java/maven/proto/pom.xml
@@ -6,7 +6,7 @@
   <parent>
     <groupId>org.tensorflow</groupId>
     <artifactId>parentpom</artifactId>
-    <version>1.10.0</version>
+    <version>1.11.0-rc1</version>
     <relativePath>../</relativePath>
   </parent>
   <artifactId>proto</artifactId>
diff --git a/tensorflow/java/maven/spark-tensorflow-connector/pom.xml b/tensorflow/java/maven/spark-tensorflow-connector/pom.xml
index 1b7995b..ce1ebfa 100644
--- a/tensorflow/java/maven/spark-tensorflow-connector/pom.xml
+++ b/tensorflow/java/maven/spark-tensorflow-connector/pom.xml
@@ -6,7 +6,7 @@
     <groupId>org.tensorflow</groupId>
     <artifactId>spark-tensorflow-connector_2.11</artifactId>
     <packaging>jar</packaging>
-    <version>1.10.0</version>
+    <version>1.11.0-rc1</version>
     <name>spark-tensorflow-connector</name>
     <url>https://www.tensorflow.org</url>
     <description>TensorFlow TFRecord connector for Apache Spark DataFrames</description>
diff --git a/tensorflow/java/maven/tensorflow-hadoop/pom.xml b/tensorflow/java/maven/tensorflow-hadoop/pom.xml
index 0fe6f4d..56346fd 100644
--- a/tensorflow/java/maven/tensorflow-hadoop/pom.xml
+++ b/tensorflow/java/maven/tensorflow-hadoop/pom.xml
@@ -5,7 +5,7 @@
     <groupId>org.tensorflow</groupId>
     <artifactId>tensorflow-hadoop</artifactId>
     <packaging>jar</packaging>
-    <version>1.10.0</version>
+    <version>1.11.0-rc1</version>
     <name>tensorflow-hadoop</name>
     <url>https://www.tensorflow.org</url>
     <description>TensorFlow TFRecord InputFormat/OutputFormat for Apache Hadoop</description>
diff --git a/tensorflow/java/maven/tensorflow/pom.xml b/tensorflow/java/maven/tensorflow/pom.xml
index 0de9024..93decea 100644
--- a/tensorflow/java/maven/tensorflow/pom.xml
+++ b/tensorflow/java/maven/tensorflow/pom.xml
@@ -6,7 +6,7 @@
   <parent>
     <groupId>org.tensorflow</groupId>
     <artifactId>parentpom</artifactId>
-    <version>1.10.0</version>
+    <version>1.11.0-rc1</version>
     <relativePath>../</relativePath>
   </parent>
   <artifactId>tensorflow</artifactId>
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 2dc2808..79f1446 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1998,6 +1998,29 @@
 )
 
 py_library(
+    name = "while_v2",
+    srcs = [
+        "ops/while_v2.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":array_ops",
+        ":cond_v2_impl",
+        ":constant_op",
+        ":control_flow_util",
+        ":framework_ops",
+        ":function_def_to_graph",
+        ":functional_ops_gen",
+        ":gradients_impl",
+        ":list_ops",
+        ":tensor_shape",
+        ":util",
+        "//tensorflow/core:protos_all_py",
+        "//tensorflow/python/eager:function",
+    ],
+)
+
+py_library(
     name = "cond_v2_impl",
     srcs = [
         "ops/cond_v2_impl.py",
@@ -2301,6 +2324,8 @@
     deps = [
         ":framework_for_generated_wrappers",
         ":logging_ops_gen",
+        ":platform",
+        ":string_ops",
         ":util",
     ],
 )
@@ -3090,7 +3115,7 @@
 
 cuda_py_test(
     name = "image_grad_test",
-    size = "small",
+    size = "medium",
     srcs = ["ops/image_grad_test.py"],
     additional_deps = [
         ":client_testlib",
@@ -3738,6 +3763,19 @@
     ],
 )
 
+cc_library(
+    name = "session_ref",
+    srcs = ["client/session_ref.cc"],
+    hdrs = ["client/session_ref.h"],
+    deps = [
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:master_proto_cc",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:replay_log_proto_cc",
+    ],
+)
+
 tf_cuda_library(
     name = "tf_session_helper",
     srcs = ["client/tf_session_helper.cc"],
@@ -3748,6 +3786,7 @@
         ":ndarray_tensor_bridge",
         ":numpy_lib",
         ":safe_ptr",
+        ":session_ref",
         ":test_ops_kernels",
         "//tensorflow/c:c_api",
         "//tensorflow/c:c_api_internal",
@@ -3760,7 +3799,6 @@
         "//tensorflow/core:graph",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core:session_ref",
         "//third_party/py/numpy:headers",
         "//third_party/python_runtime:headers",
     ],
diff --git a/tensorflow/python/autograph/README.md b/tensorflow/python/autograph/README.md
index cc54da4..bfe21b4 100644
--- a/tensorflow/python/autograph/README.md
+++ b/tensorflow/python/autograph/README.md
@@ -65,7 +65,7 @@
 Then import the `autograph` module from `tensorflow.python`:
 
 ```
-from tensorflow.contrib import autograph as ag
+from tensorflow.python import autograph as ag
 ```
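After the import, the module can be used roughly as follows (a small sketch; `square_if_positive` is a placeholder function, and `to_graph` is the conversion entry point shown elsewhere in this change):

```python
from tensorflow.python import autograph as ag

def square_if_positive(x):
  if x > 0:
    x = x * x
  return x

# Convert the Python function into an equivalent graph-building function.
tf_square_if_positive = ag.to_graph(square_if_positive)
```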
 
 ### Related links
diff --git a/tensorflow/python/autograph/__init__.py b/tensorflow/python/autograph/__init__.py
index c3448e6..5ed5e85 100644
--- a/tensorflow/python/autograph/__init__.py
+++ b/tensorflow/python/autograph/__init__.py
@@ -27,6 +27,7 @@
 from tensorflow.python.autograph.core.errors import GraphConstructionError
 from tensorflow.python.autograph.core.errors import TfRuntimeError
 from tensorflow.python.autograph.core.errors import improved_errors
+from tensorflow.python.autograph.impl.api import ConversionOptions
 from tensorflow.python.autograph.impl.api import RunMode
 from tensorflow.python.autograph.impl.api import convert
 from tensorflow.python.autograph.impl.api import converted_call
@@ -42,6 +43,7 @@
 
 _allowed_symbols = [
     # Main API
+    'ConversionOptions',
     'RunMode',
     'convert',
     'converted_call',
diff --git a/tensorflow/python/autograph/converters/call_trees.py b/tensorflow/python/autograph/converters/call_trees.py
index 6a606c4..fc2075b 100644
--- a/tensorflow/python/autograph/converters/call_trees.py
+++ b/tensorflow/python/autograph/converters/call_trees.py
@@ -238,9 +238,16 @@
     # Before we could convert all the time though, we'd need a reasonable
     # caching mechanism.
     template = """
-      ag__.converted_call(func, True, False, False, {}, args)
+      ag__.converted_call(
+          func,
+          ag__.ConversionOptions.new(recursive=recursive_val),
+          args)
     """
-    call_expr = templates.replace(template, func=node.func, args=node.args)
+    call_expr = templates.replace(
+        template,
+        func=node.func,
+        recursive_val=parser.parse_expression(str(self.ctx.program.recursive)),
+        args=node.args)
     new_call = call_expr[0].value
     # TODO(mdan): Improve the template mechanism to better support this.
     new_call.keywords = node.keywords
diff --git a/tensorflow/python/autograph/core/converter.py b/tensorflow/python/autograph/core/converter.py
index 7b3905f..80928ae 100644
--- a/tensorflow/python/autograph/core/converter.py
+++ b/tensorflow/python/autograph/core/converter.py
@@ -63,10 +63,8 @@
 from __future__ import division
 from __future__ import print_function
 
-import collections
 from enum import Enum
 
-
 from tensorflow.python.autograph.core import config
 from tensorflow.python.autograph.core import naming
 from tensorflow.python.autograph.pyct import anno
@@ -129,9 +127,8 @@
     self.autograph_module = autograph_module
     self.uncompiled_modules = uncompiled_modules
 
-    # Required to output dependencies in discovery order, which should match
-    # the reverse dependency order.
-    self.dependency_cache = collections.OrderedDict()
+    self.conversion_order = []
+    self.dependency_cache = {}
     self.additional_imports = set()
     self.name_map = {}
 
@@ -177,6 +174,7 @@
         self.name_map[o] = name
 
   def add_to_cache(self, original_entity, converted_ast):
+    self.conversion_order.append(original_entity)
     self.dependency_cache[original_entity] = converted_ast
 
 
diff --git a/tensorflow/python/autograph/core/converter_testing.py b/tensorflow/python/autograph/core/converter_testing.py
index 0a0c6f9..7ce1b7c 100644
--- a/tensorflow/python/autograph/core/converter_testing.py
+++ b/tensorflow/python/autograph/core/converter_testing.py
@@ -93,11 +93,21 @@
       self.dynamic_calls.append(args)
       return 7
 
+    class ConversionOptions(object):
+      """Mock version of api.ConversionOptions."""
+
+      def __init__(self, recursive):
+        self.recursive = recursive
+
+      @classmethod
+      def new(cls, recursive):
+        return cls(recursive)
+
     try:
       result, source = compiler.ast_to_object(node, include_source_map=True)
 
       result.tf = self.make_fake_mod('fake_tf', *symbols)
-      fake_ag = self.make_fake_mod('fake_ag', converted_call)
+      fake_ag = self.make_fake_mod('fake_ag', converted_call, ConversionOptions)
       fake_ag.__dict__.update(operators.__dict__)
       fake_ag.__dict__['utils'] = utils
       fake_ag.__dict__['rewrite_graph_construction_error'] = (
diff --git a/tensorflow/python/autograph/core/errors.py b/tensorflow/python/autograph/core/errors.py
index 0750353..23f8c5b 100644
--- a/tensorflow/python/autograph/core/errors.py
+++ b/tensorflow/python/autograph/core/errors.py
@@ -208,7 +208,6 @@
   """
   try:
     cleaned_traceback = _cut_traceback_loops(source_map, error.op.traceback)
-    # cleaned_traceback = error.op.traceback
     cleaned_traceback = _rewrite_tb(source_map, cleaned_traceback)
 
     op_name = error.op.name
diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py
index 669d36b..1dc97d2 100644
--- a/tensorflow/python/autograph/impl/api.py
+++ b/tensorflow/python/autograph/impl/api.py
@@ -18,7 +18,8 @@
 from __future__ import division
 from __future__ import print_function
 
-from functools import wraps
+import collections
+import functools
 
 from enum import Enum
 
@@ -38,6 +39,41 @@
 # (currently we require (module + class name, type))
 
 
+class ConversionOptions(
+    collections.namedtuple('ConversionOptions',
+                           ('recursive', 'verbose', 'strip_decorators',
+                            'force_conversion', 'arg_types'))):
+  """Container for conversion flags.
+
+  Attributes:
+    recursive: bool, whether to recursively convert any user functions or
+        classes that the converted function may use.
+    verbose: bool, whether to log the compiled code.
+    strip_decorators: Tuple[Callable], contains decorators that should be
+        excluded from the compiled output. By default, when converting a
+        function before the decorators are applied, the compiled output will
+        include those decorators.
+    force_conversion: bool, whether to force converting the target entity.
+        When force_conversion is turned off, the converter may decide to
+        return the function as-is.
+    arg_types: Optional[Dict[Text, Type]], type hints for symbols including
+        function arguments.
+  """
+
+  @classmethod
+  def new(cls,
+          recursive=False,
+          verbose=False,
+          strip_decorators=None,
+          force_conversion=False,
+          arg_types=None):
+    return cls(recursive=recursive,
+               verbose=verbose,
+               strip_decorators=strip_decorators or (),
+               force_conversion=force_conversion,
+               arg_types=arg_types or {})
+
+
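Put together, the new calling convention looks roughly like this (a sketch mirroring the usage in the updated unit tests; `test_fn` is a placeholder and the symbols are the `ConversionOptions` and `converted_call` exported by this change):

```python
import tensorflow as tf
from tensorflow.python import autograph as ag

def test_fn(x):
  if x < 0:
    return -x
  return x

# A single options bundle replaces the separate recursive/verbose/... flags.
opts = ag.ConversionOptions.new(recursive=True)
y = ag.converted_call(test_fn, opts, tf.constant(-1))

with tf.Session() as sess:
  print(sess.run(y))  # 1
```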
 # TODO(mdan): This should behave like to_graph (e.g. convert statically).
 def convert(recursive=False, verbose=False):
   """Decorator that compiles a function to use TensorFlow ops.
@@ -59,9 +95,15 @@
   def decorator(f):
     """Decorator implementation."""
 
-    @wraps(f)
+    @functools.wraps(f)
     def wrapper(*args, **kwargs):
-      return converted_call(f, recursive, verbose, True, {}, *args, **kwargs)
+      return converted_call(
+          f,
+          ConversionOptions.new(
+              recursive=recursive,
+              verbose=verbose,
+              force_conversion=True,
+          ), *args, **kwargs)
 
     wrapper = tf_decorator.make_decorator(f, wrapper)
 
@@ -107,11 +149,11 @@
   def decorator(f):
     """Decorator implementation."""
 
-    @wraps(f)
+    @functools.wraps(f)
     def graph_wrapper(*args, **kwargs):
       return f(*args, **kwargs)
 
-    @wraps(f)
+    @functools.wraps(f)
     def py_func_wrapper(*args, **kwargs):
       if kwargs:
         raise NotImplementedError('RunMode.PY_FUNC does not yet support kwargs')
@@ -135,12 +177,11 @@
 
 
 # TODO(mdan): Move to a private, undocumented module.
-def converted_call(f, recursive, verbose, force_conversion, arg_types, *args,
-                   **kwargs):
+def converted_call(f, options, *args, **kwargs):
   """Compiles a function call inline. For internal use only."""
   # TODO(mdan): This needs cleanup.
   # In particular, we may want to avoid renaming functions altogether.
-  if not force_conversion and conversion.is_whitelisted_for_graph(f):
+  if not options.force_conversion and conversion.is_whitelisted_for_graph(f):
     return f(*args, **kwargs)
 
   unknown_arg_value = object()  # Sentinel for arguments of unknown value
@@ -183,8 +224,8 @@
       continue
     arg_class = arg.__class__
     # If arg_value_hints specifies any name, use that instead.
-    if name not in arg_types:
-      arg_types[name] = (arg_class.__name__, arg_class)
+    if name not in options.arg_types:
+      options.arg_types[name] = (arg_class.__name__, arg_class)
 
   # When called from within a decorator, this is the only indication that
   # the function is a method - it appears that the decorator is applied
@@ -199,23 +240,25 @@
 
   converted_f = to_graph(
       target_entity,
-      recursive=recursive,
-      verbose=verbose,
+      recursive=options.recursive,
+      verbose=options.verbose,
       arg_values=arg_values,
-      arg_types=arg_types,
-      partial_types=partial_types)
+      arg_types=options.arg_types,
+      partial_types=partial_types,
+      strip_decorators=options.strip_decorators)
   return converted_f(*effective_args, **kwargs)
 
 
 # TODO(mdan): Rename: to_ops?
-# TODO(mdan): Looki into overloading as function and decorator, like tfe.defun.
+# TODO(mdan): Look into overloading as function and decorator, like tfe.defun?
 # TODO(mdan): Remove partial_types.
 def to_graph(e,
              recursive=True,
              verbose=False,
              arg_values=None,
              arg_types=None,
-             partial_types=None):
+             partial_types=None,
+             strip_decorators=None):
   """Converts a Python entity into equivalent code that uses TensorFlow ops.
 
   Supported Python entities include:
@@ -234,6 +277,8 @@
     arg_types: Optional[Dict[Text, Type]], type hints for symbols including
         function arguments.
     partial_types: Set[Type], reserved for internal use.
+    strip_decorators: Tuple[Callable], same as
+        ConversionOptions.strip_decorators.
 
   Returns:
     Union[Callable, Type], the converted entity, which is the same kind as e
@@ -243,9 +288,13 @@
   Raises:
     ValueError: If the entity could not be converted.
   """
+  if strip_decorators is None:
+    strip_decorators = ()
+  strip_decorators += (convert, do_not_convert, converted_call)
+
   program_ctx = converter.ProgramContext(
       recursive=recursive,
-      autograph_decorators=(convert, do_not_convert, converted_call),
+      autograph_decorators=strip_decorators,
       partial_types=partial_types,
       autograph_module=tf_inspect.getmodule(to_graph),
       uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
@@ -253,8 +302,9 @@
                                                   arg_types)
 
   nodes = []
-  for dep in reversed(tuple(program_ctx.dependency_cache.values())):
-    nodes.extend(dep)
+  for dep in reversed(program_ctx.conversion_order):
+    nodes.extend(program_ctx.dependency_cache[dep])
+
   compiled_module, compiled_src = compiler.ast_to_object(
       nodes,
       source_prefix=program_ctx.required_imports,
@@ -322,7 +372,7 @@
   conversion.entity_to_graph(e, program_ctx, arg_values, arg_types)
 
   code = '\n'.join(
-      compiler.ast_to_source(dep, indentation)
-      for dep in reversed(tuple(program_ctx.dependency_cache.values())))
+      compiler.ast_to_source(program_ctx.dependency_cache[dep], indentation)
+      for dep in reversed(program_ctx.conversion_order))
 
   return program_ctx.required_imports + '\n\n' + code
diff --git a/tensorflow/python/autograph/impl/api_test.py b/tensorflow/python/autograph/impl/api_test.py
index 54e12f0..e0770ef 100644
--- a/tensorflow/python/autograph/impl/api_test.py
+++ b/tensorflow/python/autograph/impl/api_test.py
@@ -32,7 +32,6 @@
 
 tf = utils.fake_tf()
 
-
 class ApiTest(test.TestCase):
 
   def setUp(self):
@@ -180,8 +179,9 @@
       @api.convert(recursive=True)
       def test_method(self, x, s, a):
         while tf.reduce_sum(x) > s:
-          x //= api.converted_call(self.called_member, False, False, False, {},
-                                   self, a)
+          x //= api.converted_call(
+              self.called_member,
+              api.ConversionOptions.new(), self, a)
         return x
 
     tc = TestClass()
@@ -192,7 +192,7 @@
       self.assertListEqual([0, 1], sess.run(x).tolist())
 
   def test_converted_call_builtin(self):
-    x = api.converted_call(range, False, False, False, {}, 3)
+    x = api.converted_call(range, api.ConversionOptions.new(), 3)
     self.assertEqual((0, 1, 2), tuple(x))
 
   def test_converted_call_function(self):
@@ -203,7 +203,7 @@
       return x
 
     with self.test_session() as sess:
-      x = api.converted_call(test_fn, False, False, False, {},
+      x = api.converted_call(test_fn, api.ConversionOptions.new(),
                              constant_op.constant(-1))
       self.assertEqual(1, sess.run(x))
 
@@ -221,7 +221,7 @@
 
     with self.test_session() as sess:
       tc = TestClass(constant_op.constant(-1))
-      x = api.converted_call(tc.test_method, False, False, False, {}, tc)
+      x = api.converted_call(tc.test_method, api.ConversionOptions.new(), tc)
       self.assertEqual(1, sess.run(x))
 
   def test_converted_call_method_by_class(self):
@@ -238,7 +238,9 @@
 
     with self.test_session() as sess:
       tc = TestClass(constant_op.constant(-1))
-      x = api.converted_call(TestClass.test_method, False, False, False, {}, tc)
+      x = api.converted_call(
+          TestClass.test_method,
+          api.ConversionOptions.new(), tc)
       self.assertEqual(1, sess.run(x))
 
   def test_converted_call_callable_object(self):
@@ -255,7 +257,7 @@
 
     with self.test_session() as sess:
       tc = TestClass(constant_op.constant(-1))
-      x = api.converted_call(tc, False, False, False, {})
+      x = api.converted_call(tc, api.ConversionOptions.new())
       self.assertEqual(1, sess.run(x))
 
   def test_converted_call_constructor(self):
@@ -271,7 +273,7 @@
         return self.x
 
     with self.test_session() as sess:
-      tc = api.converted_call(TestClass, False, False, False, {},
+      tc = api.converted_call(TestClass, api.ConversionOptions.new(),
                               constant_op.constant(-1))
       # tc is now a converted object.
       x = tc.test_method()
@@ -283,12 +285,12 @@
       return x == 0
 
     with self.test_session() as sess:
-      x = api.converted_call(f, False, False, False, {},
+      x = api.converted_call(f, api.ConversionOptions.new(),
                              constant_op.constant(0))
       self.assertTrue(sess.run(x))
 
       converted_f = api.to_graph(f)
-      x = api.converted_call(converted_f, False, False, False, {},
+      x = api.converted_call(converted_f, api.ConversionOptions.new(),
                              constant_op.constant(0))
       self.assertTrue(sess.run(x))
 
diff --git a/tensorflow/python/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py
index 928ff9e..a0d13c8 100644
--- a/tensorflow/python/autograph/impl/conversion.py
+++ b/tensorflow/python/autograph/impl/conversion.py
@@ -255,6 +255,7 @@
     # internal modules.
     ag_internal = imp.new_module('autograph')
     ag_internal.converted_call = autograph_module.converted_call
+    ag_internal.ConversionOptions = autograph_module.ConversionOptions
     ag_internal.utils = utils
     ag_internal.rewrite_graph_construction_error = (
         errors.rewrite_graph_construction_error)
diff --git a/tensorflow/python/autograph/pyct/cfg.py b/tensorflow/python/autograph/pyct/cfg.py
index 1433f9a..fca0eb6 100644
--- a/tensorflow/python/autograph/pyct/cfg.py
+++ b/tensorflow/python/autograph/pyct/cfg.py
@@ -27,6 +27,7 @@
 from __future__ import print_function
 
 import collections
+import weakref
 from enum import Enum
 
 # pylint:disable=g-bad-import-order
@@ -61,7 +62,10 @@
 
   def freeze(self):
     self.next = frozenset(self.next)
-    self.prev = frozenset(self.prev)
+    # Assumption: All CFG nodes have identical life spans, because the graph
+    # owns them. Nodes should never be used outside the context of an existing
+    # graph.
+    self.prev = weakref.WeakSet(self.prev)
 
   def __repr__(self):
     if isinstance(self.ast_node, gast.FunctionDef):
@@ -256,7 +260,7 @@
     """Resets the state of this factory."""
     self.head = None
     self.errors = set()
-    self.node_index = collections.OrderedDict()
+    self.node_index = {}
 
     # TODO(mdan): Too many primitives. Use classes.
     self.leaves = set()
@@ -309,7 +313,10 @@
     """Grows the graph by adding a CFG node following the current leaves."""
     if ast_node is self.node_index:
       raise ValueError('%s added twice' % ast_node)
-    node = Node(next_=set(), prev=set(), ast_node=ast_node)
+    # Assumption: All CFG nodes have identical life spans, because the graph
+    # owns them. Nodes should never be used outside the context of an existing
+    # graph.
+    node = Node(next_=set(), prev=weakref.WeakSet(), ast_node=ast_node)
     self.node_index[ast_node] = node
     self.owners[node] = frozenset(self.active_stmts)
 
diff --git a/tensorflow/python/autograph/pyct/parser.py b/tensorflow/python/autograph/pyct/parser.py
index 112ed46..6368635 100644
--- a/tensorflow/python/autograph/pyct/parser.py
+++ b/tensorflow/python/autograph/pyct/parser.py
@@ -31,8 +31,21 @@
 def parse_entity(entity):
   """Returns the AST of given entity."""
   source = tf_inspect.getsource(entity)
+  # Comments and multiline strings can appear at arbitrary indentation levels,
+  # causing textwrap.dedent to not correctly dedent source code.
+  # TODO(b/115884650): Automatic handling of comments/multiline strings.
   source = textwrap.dedent(source)
-  return parse_str(source), source
+  try:
+    return parse_str(source), source
+  except IndentationError:
+    # Because we are parsing the source code of entities that have already
+    # successfully parsed once, any IndentationErrors are guaranteed to be
+    # caused by insufficient dedenting.
+    raise ValueError(
+        'Failed to dedent prior to parsing source code. If you have comments '
+        'or multiline strings in your code, try indenting them. '
+        'Multiline strings can be rewritten using textwrap.dedent.\n'
+        'Offending source code: \n %s' % source)
 
 
 def parse_str(src):
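The dedent limitation called out above can be seen with plain `textwrap` (no TensorFlow needed): a comment at column zero pins the common leading whitespace to the empty string, so nothing is stripped and the subsequent re-parse of the indented `def` fails.

```python
import textwrap

src = (
    "    def f():\n"
    "# a comment at column zero\n"
    "        pass\n"
)
# Output is unchanged: dedent finds no common leading whitespace to remove,
# so compiling this source still raises IndentationError.
print(textwrap.dedent(src))
```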
diff --git a/tensorflow/python/autograph/pyct/parser_test.py b/tensorflow/python/autograph/pyct/parser_test.py
index d0b465e..d3a7b7a 100644
--- a/tensorflow/python/autograph/pyct/parser_test.py
+++ b/tensorflow/python/autograph/pyct/parser_test.py
@@ -42,6 +42,22 @@
     """))
     self.assertEqual('f', mod.body[0].name)
 
+  def test_parse_comments(self):
+    def f():
+# unindented comment
+      pass
+    with self.assertRaises(ValueError):
+      parser.parse_entity(f)
+
+  def test_parse_multiline_strings(self):
+    def f():
+      print("""
+some
+multiline
+string""")
+    with self.assertRaises(ValueError):
+      parser.parse_entity(f)
+
   def test_parse_expression(self):
     node = parser.parse_expression('a.b')
     self.assertEqual('a', node.value.id)
diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity.py b/tensorflow/python/autograph/pyct/static_analysis/activity.py
index 9cb5991..086eda7 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/activity.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/activity.py
@@ -22,6 +22,7 @@
 from __future__ import print_function
 
 import copy
+import weakref
 
 import gast
 
@@ -126,7 +127,10 @@
       self.parent.mark_read(name)
 
   def mark_param(self, name, owner):
-    self.params[name] = owner
+    # Assumption: all AST nodes have the same life span. This lets us use
+    # a weak reference to mark the connection between a symbol node and the
+    # function node whose argument that symbol is.
+    self.params[name] = weakref.ref(owner)
 
   def mark_creation(self, name, writes_create_symbol=False):
     """Mark a qualified name as created."""
diff --git a/tensorflow/python/autograph/pyct/static_analysis/live_values.py b/tensorflow/python/autograph/pyct/static_analysis/live_values.py
index 48b442f..36b9e70 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/live_values.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/live_values.py
@@ -29,10 +29,11 @@
 from tensorflow.python.autograph.pyct import transformer
 from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
 
+
 # TODO(aqj): Do we need this? Do other builtins fail in similar ways
 # See b/114389775 for a related bug in pyct
 # These symbols are legal in Python, but don't appear in the namespace.
-_special_symbols = {'range': range}
+_SPECIAL_SYMBOLS = {'range': range, 'print': print}
 
 
 class LiveValueResolver(transformer.Base):
@@ -71,8 +72,10 @@
             # If the symbol value is for example a primitive, then it will not
             # have a name.
             pass
-        elif node.id in _special_symbols:
-          anno.setanno(node, 'live_val', _special_symbols[node.id])
+        elif node.id in _SPECIAL_SYMBOLS:
+          # Note: if the user redefined any of these symbols, then they would
+          # be visible in the namespace and we would never reach this branch.
+          anno.setanno(node, 'live_val', _SPECIAL_SYMBOLS[node.id])
         else:
           pass
           # TODO(mdan): Should we raise an error here?
@@ -86,7 +89,8 @@
 
       if has_single_def:
         def_, = defs
-        if def_.param_of is self.enclosing_entities[0]:
+        # Note: param_of is a weakref.
+        if def_.param_of and def_.param_of() is self.enclosing_entities[0]:
           if node.id in self.entity_info.arg_values:
             obj = self.entity_info.arg_values[node.id]
             anno.setanno(node, 'live_val', obj)
diff --git a/tensorflow/python/autograph/pyct/templates.py b/tensorflow/python/autograph/pyct/templates.py
index 68c2a35..1bf0515 100644
--- a/tensorflow/python/autograph/pyct/templates.py
+++ b/tensorflow/python/autograph/pyct/templates.py
@@ -109,6 +109,7 @@
     if not node.ctx:
       raise ValueError('node %s is missing ctx value' % node)
 
+  # TODO(mdan): Rewrite _check and _set using a separate transformer.
   def _check_inner_children_have_context(self, node):
     if isinstance(node, gast.Attribute):
       self._check_inner_children_have_context(node.value)
@@ -131,6 +132,11 @@
         self._check_inner_children_have_context(node.upper)
       if node.step:
         self._check_inner_children_have_context(node.step)
+    elif isinstance(node, gast.BinOp):
+      self._check_inner_children_have_context(node.left)
+      self._check_inner_children_have_context(node.right)
+    elif isinstance(node, gast.UnaryOp):
+      self._check_inner_children_have_context(node.operand)
     elif isinstance(node, gast.Name):
       self._check_has_context(node)
     elif isinstance(node, (gast.Str, gast.Num)):
@@ -166,6 +172,11 @@
     elif isinstance(node, gast.Subscript):
       self._set_inner_child_context(node.value, ctx)
       self._check_inner_children_have_context(node.slice)
+    elif isinstance(node, gast.BinOp):
+      self._check_inner_children_have_context(node.left)
+      self._check_inner_children_have_context(node.right)
+    elif isinstance(node, gast.UnaryOp):
+      self._check_inner_children_have_context(node.operand)
     elif isinstance(node, (gast.Str, gast.Num)):
       pass
     else:
diff --git a/tensorflow/python/autograph/pyct/templates_test.py b/tensorflow/python/autograph/pyct/templates_test.py
index 66268cf..078d9a1 100644
--- a/tensorflow/python/autograph/pyct/templates_test.py
+++ b/tensorflow/python/autograph/pyct/templates_test.py
@@ -132,6 +132,18 @@
     self.assertIsInstance(node.body[0].targets[0].elts[0].ctx, gast.Store)
     self.assertIsInstance(node.body[0].targets[0].elts[1].ctx, gast.Store)
 
+  def test_replace_expression_context(self):
+    template = """
+      def test_fn(foo):
+        foo
+    """
+
+    node = templates.replace(
+        template, foo=parser.parse_expression('a + 2 * b / -c'))[0]
+    self.assertIsInstance(node.body[0].ctx, gast.Load)
+    self.assertIsInstance(node.body[0].left.ctx, gast.Load)
+    self.assertIsInstance(node.body[0].right.left.right.ctx, gast.Load)
+
   def test_replace_complex_context(self):
     template = """
       def test_fn(foo):
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index ae0ad27..c963cfd 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -178,16 +178,30 @@
     feed_function_for_partial_run: A callable for specifying tensor values to
       feed when setting up a partial run, which takes a `tensor_type` type
       object as input, and returns a list of Tensors.
+
+  Raises:
+    ValueError: If `tensor_type` has already been registered.
   """
   for conversion_function in _REGISTERED_EXPANSIONS:
     if issubclass(conversion_function[0], tensor_type):
-      raise ValueError('%s has already been registered so ignore it.',
+      raise ValueError('%s has already been registered so ignore it.' %
                        tensor_type)
-      return
+
   _REGISTERED_EXPANSIONS.insert(0, (tensor_type, fetch_function, feed_function,
                                     feed_function_for_partial_run))
 
 
+def _is_attrs_instance(obj):
+  """Returns True if the given obj is an instance of attrs-decorated class."""
+  return getattr(obj.__class__, '__attrs_attrs__', None) is not None
+
+
+def _get_attrs_values(obj):
+  """Returns the list of values from an attrs instance."""
+  attrs = getattr(obj.__class__, '__attrs_attrs__')
+  return [getattr(obj, a.name) for a in attrs]
+
+
 class _FetchMapper(object):
   """Definition of the interface provided by fetch mappers.
 
@@ -247,6 +261,8 @@
       return _ListFetchMapper(fetch)
     elif isinstance(fetch, collections.Mapping):
       return _DictFetchMapper(fetch)
+    elif _is_attrs_instance(fetch):
+      return _AttrsFetchMapper(fetch)
     else:
       # Look for a handler in the registered expansions.
       for tensor_type, fetch_fn, _, _ in _REGISTERED_EXPANSIONS:
@@ -398,6 +414,32 @@
     return results
 
 
+class _AttrsFetchMapper(_FetchMapper):
+  """Fetch mapper for attrs decorated classes."""
+
+  def __init__(self, fetches):
+    """Creates a _AttrsFetchMapper.
+
+    Args:
+      fetches: An instance of an attrs decorated class.
+    """
+    values = _get_attrs_values(fetches)
+    self._fetch_type = type(fetches)
+    self._mappers = [
+        _FetchMapper.for_fetch(fetch) for fetch in values
+    ]
+    self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
+
+  def unique_fetches(self):
+    return self._unique_fetches
+
+  def build_results(self, values):
+    results = []
+    for m, vi in zip(self._mappers, self._value_indices):
+      results.append(m.build_results([values[j] for j in vi]))
+    return self._fetch_type(*results)
+
+
 class _FetchHandler(object):
   """Handler for structured fetches.
 
diff --git a/tensorflow/python/client/session_ref.cc b/tensorflow/python/client/session_ref.cc
new file mode 100644
index 0000000..b2300df
--- /dev/null
+++ b/tensorflow/python/client/session_ref.cc
@@ -0,0 +1,515 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/python/client/session_ref.h"
+
+#include <stdlib.h>
+#include <memory>
+#include <utility>
+
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/io/record_writer.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/protobuf/master.pb.h"
+#include "tensorflow/core/protobuf/named_tensor.pb.h"
+#include "tensorflow/core/protobuf/replay_log.pb.h"
+
+namespace tensorflow {
+
+namespace {
+
+// Scope helper to track active calls and manage session lifetime.
+// SessionRef blocks closing until all active calls complete or are cancelled.
+struct RunCounter {
+  std::shared_ptr<Session> session;
+  uint64* value;
+  mutex* m;
+  condition_variable* cv;
+
+  explicit RunCounter(std::shared_ptr<Session> s, uint64* v, mutex* m,
+                      condition_variable* cv)
+      : session(std::move(s)), value(v), m(m), cv(cv) {
+    mutex_lock l(*m);
+    ++*value;
+  }
+
+  ~RunCounter() {
+    mutex_lock l(*m);
+    if (--*value == 0) {
+      cv->notify_all();
+    }
+  }
+};
+
+std::string SessionToHandle(Session* session) {
+  return strings::Printf("%llu", reinterpret_cast<uint64>(session));
+}
+
+// The Session interface has many methods of the form:
+//
+// X(a, b);
+// X(RunOptions, a, b);
+//
+// Not all sessions support the second case (with an empty RunOptions()).
+// We use the address of this singleton as a sentinel to dispatch to the
+// correct call.
+RunOptions* kEmptyRunOptions() {
+  static RunOptions* options = new RunOptions();
+  return options;
+}
+
+}  // namespace
+
+// Run the given session operation, recording start and end timestamps.
+// If the operation returns a bad status, return after flushing the current
+// log request.  This should be run _after_ all request information has been
+// added to the current op.
+#define RUN_WITH_TIMESTAMP(OpName, ...)              \
+  op.set_start_time_us(Env::Default()->NowMicros()); \
+  Status status = session->OpName(__VA_ARGS__);      \
+  op.set_end_time_us(Env::Default()->NowMicros());   \
+  if (!status.ok()) {                                \
+    Flush(op).IgnoreError();                         \
+    return status;                                   \
+  }
+
+// Records requests (and optionally responses) performed against a session.
+// The resulting replay log can be used with the `tf_replay` tool to replicate
+// the operations against a simulated environment, without requiring the
+// original code or cluster setup.
+//
+// Session logging is enabled by setting the TF_REPLAY_LOG_FILE environment
+// variable.
+class SessionLogger {
+ public:
+  SessionLogger() {
+    std::string log_name = getenv("TF_REPLAY_LOG_FILE");
+    TF_CHECK_OK(
+        Env::Default()->RecursivelyCreateDir(string(io::Dirname(log_name))));
+    Env::Default()->DeleteFile(log_name).IgnoreError();
+    TF_CHECK_OK(Env::Default()->NewWritableFile(log_name, &log_file_));
+
+    log_writer_ = absl::make_unique<io::RecordWriter>(log_file_.get());
+  }
+
+  Status RecordCreateSession(Session* session) {
+    LOG(INFO) << "Capturing devices for session.";
+    ReplayOp op;
+    NewReplaySession* req = op.mutable_new_replay_session();
+
+    std::vector<DeviceAttributes> devices;
+    TF_CHECK_OK(session->ListDevices(&devices));
+    for (const DeviceAttributes& dev : devices) {
+      *req->mutable_devices()->add_local_device() = dev;
+    }
+
+    req->set_session_handle(SessionToHandle(session));
+    return Flush(op);
+  }
+
+  Status RecordRun(Session* session,
+                   const std::vector<std::pair<string, Tensor> >& inputs,
+                   const std::vector<string>& output_tensor_names,
+                   const std::vector<string>& target_node_names,
+                   std::vector<Tensor>* outputs) {
+    return RecordRun(session, *kEmptyRunOptions(), inputs, output_tensor_names,
+                     target_node_names, outputs, nullptr);
+  }
+
+  Status RecordRun(Session* session, const RunOptions& run_options,
+                   const std::vector<std::pair<string, Tensor> >& inputs,
+                   const std::vector<string>& output_tensor_names,
+                   const std::vector<string>& target_node_names,
+                   std::vector<Tensor>* outputs, RunMetadata* run_metadata) {
+    ReplayOp op;
+    RunStepRequest* req = op.mutable_run_step();
+    RunStepResponse* resp = op.mutable_run_step_response();
+
+    req->set_session_handle(SessionToHandle(session));
+    *req->mutable_options() = run_options;
+
+    for (const auto& it : inputs) {
+      NamedTensorProto* feed = req->add_feed();
+      feed->set_name(it.first);
+      it.second.AsProtoField(feed->mutable_tensor());
+    }
+
+    // Build an index from fetch tensor name to first index in
+    // output_tensor_names.
+    std::unordered_map<string, int> output_name_to_offset;
+    for (int i = 0; i < output_tensor_names.size(); ++i) {
+      const string& name = output_tensor_names[i];
+      if (output_name_to_offset.insert(std::make_pair(name, i)).second) {
+        req->add_fetch(name);
+      }
+    }
+    for (const string& target : target_node_names) {
+      req->add_target(target);
+    }
+
+    if (&run_options == kEmptyRunOptions()) {
+      RUN_WITH_TIMESTAMP(Run, inputs, output_tensor_names, target_node_names,
+                         outputs);
+    } else {
+      RUN_WITH_TIMESTAMP(Run, run_options, inputs, output_tensor_names,
+                         target_node_names, outputs, run_metadata);
+    }
+
+    for (size_t i = 0; i < outputs->size(); ++i) {
+      const Tensor& tensor = (*outputs)[i];
+      NamedTensorProto* tproto = resp->add_tensor();
+      tensor.AsProtoField(tproto->mutable_tensor());
+      tproto->set_name(output_tensor_names[i]);
+    }
+
+    if (run_metadata) {
+      *resp->mutable_metadata() = *run_metadata;
+    }
+
+    return Flush(op);
+  }
+
+  Status RecordCreate(Session* session, const GraphDef& graph) {
+    return RecordCreate(session, *kEmptyRunOptions(), graph);
+  }
+
+  // N.B. RunOptions is not stored (it has no entry in CreateRequest)
+  Status RecordCreate(Session* session, const RunOptions& run_options,
+                      const GraphDef& graph) {
+    ReplayOp op;
+    CreateSessionRequest* req = op.mutable_create_session();
+    *req->mutable_graph_def() = graph;
+
+    CreateSessionResponse* resp = op.mutable_create_session_response();
+    if (&run_options == kEmptyRunOptions()) {
+      RUN_WITH_TIMESTAMP(Create, graph);
+    } else {
+      RUN_WITH_TIMESTAMP(Create, run_options, graph);
+    }
+    resp->set_session_handle(SessionToHandle(session));
+    return Flush(op);
+  }
+
+  Status RecordExtend(Session* session, const GraphDef& graph) {
+    return RecordExtend(session, *kEmptyRunOptions(), graph);
+  }
+
+  // N.B. RunOptions is not stored (it has no entry in ExtendRequest)
+  Status RecordExtend(Session* session, const RunOptions& run_options,
+                      const GraphDef& graph) {
+    ReplayOp op;
+    ExtendSessionRequest* req = op.mutable_extend_session();
+    op.mutable_extend_session_response();
+    req->set_session_handle(SessionToHandle(session));
+    *req->mutable_graph_def() = graph;
+    if (&run_options == kEmptyRunOptions()) {
+      RUN_WITH_TIMESTAMP(Extend, graph);
+    } else {
+      RUN_WITH_TIMESTAMP(Extend, run_options, graph);
+    }
+
+    return Flush(op);
+  }
+
+  Status RecordClose(Session* session) {
+    return RecordClose(session, *kEmptyRunOptions());
+  }
+
+  // N.B. RunOptions is not stored (it has no entry in CloseRequest)
+  Status RecordClose(Session* session, const RunOptions& run_options) {
+    mutex_lock l(log_mutex_);
+    ReplayOp op;
+    CloseSessionRequest* req = op.mutable_close_session();
+    req->set_session_handle(SessionToHandle(session));
+    op.mutable_close_session_response();
+    if (&run_options == kEmptyRunOptions()) {
+      RUN_WITH_TIMESTAMP(Close);
+    } else {
+      RUN_WITH_TIMESTAMP(Close, run_options);
+    }
+    return Flush(op);
+  }
+
+  Status RecordListDevices(Session* session,
+                           std::vector<DeviceAttributes>* response) {
+    mutex_lock l(log_mutex_);
+    ReplayOp op;
+    ListDevicesRequest* req = op.mutable_list_devices();
+    ListDevicesResponse* resp = op.mutable_list_devices_response();
+    req->set_session_handle(SessionToHandle(session));
+    RUN_WITH_TIMESTAMP(ListDevices, response);
+
+    // TODO(power) -- local vs remote device distinction is lost here!
+    *resp->mutable_local_device() = {response->begin(), response->end()};
+    return Flush(op);
+  }
+
+  Status RecordPRunSetup(Session* session,
+                         const std::vector<string>& input_names,
+                         const std::vector<string>& output_names,
+                         const std::vector<string>& target_nodes,
+                         string* handle) {
+    mutex_lock l(log_mutex_);
+    ReplayOp op;
+    PartialRunSetupRequest* req = op.mutable_partial_run_setup();
+    req->set_session_handle(SessionToHandle(session));
+    for (auto& input : input_names) {
+      req->add_feed(input);
+    }
+    for (auto& output : output_names) {
+      req->add_fetch(output);
+    }
+    for (auto& target : target_nodes) {
+      req->add_target(target);
+    }
+    RUN_WITH_TIMESTAMP(PRunSetup, input_names, output_names, target_nodes,
+                       handle);
+    op.mutable_partial_run_setup_response()->set_partial_run_handle(*handle);
+    return Flush(op);
+  }
+
+  Status RecordPRun(Session* session, const string& handle,
+                    const std::vector<std::pair<string, Tensor> >& inputs,
+                    const std::vector<string>& output_names,
+                    std::vector<Tensor>* outputs) {
+    ReplayOp op;
+    RunStepRequest* req = op.mutable_run_step();
+    RunStepResponse* resp = op.mutable_run_step_response();
+    req->set_session_handle(SessionToHandle(session));
+
+    // Mark this step as a partial run for replay.
+    req->set_partial_run_handle(handle);
+    for (auto& input : inputs) {
+      auto* feed = req->add_feed();
+      feed->set_name(input.first);
+      input.second.AsProtoField(feed->mutable_tensor());
+    }
+
+    for (auto& output : output_names) {
+      req->add_fetch(output);
+    }
+
+    RUN_WITH_TIMESTAMP(PRun, handle, inputs, output_names, outputs);
+
+    for (size_t i = 0; i < outputs->size(); ++i) {
+      const Tensor& tensor = (*outputs)[i];
+      NamedTensorProto* tproto = resp->add_tensor();
+      tensor.AsProtoField(tproto->mutable_tensor());
+      tproto->set_name(output_names[i]);
+    }
+
+    return Flush(op);
+  }
+
+  Status RecordMakeCallable(Session* session,
+                            const CallableOptions& callable_options,
+                            Session::CallableHandle* handle) {
+    ReplayOp op;
+    MakeCallableRequest* req = op.mutable_make_callable();
+    req->set_session_handle(SessionToHandle(session));
+    *req->mutable_options() = callable_options;
+
+    RUN_WITH_TIMESTAMP(MakeCallable, callable_options, handle);
+
+    MakeCallableResponse* resp = op.mutable_make_callable_response();
+    resp->set_handle(*handle);
+
+    return Flush(op);
+  }
+
+  Status RecordRunCallable(Session* session, Session::CallableHandle handle,
+                           const std::vector<Tensor>& feed_tensors,
+                           std::vector<Tensor>* fetch_tensors,
+                           RunMetadata* run_metadata) {
+    ReplayOp op;
+    RunCallableRequest* req = op.mutable_run_callable();
+    req->set_session_handle(SessionToHandle(session));
+    req->set_handle(handle);
+    for (auto& tensor : feed_tensors) {
+      tensor.AsProtoField(req->add_feed());
+    }
+    RUN_WITH_TIMESTAMP(RunCallable, handle, feed_tensors, fetch_tensors,
+                       run_metadata);
+
+    RunCallableResponse* resp = op.mutable_run_callable_response();
+    if (run_metadata) {
+      *resp->mutable_metadata() = *run_metadata;
+    }
+    for (const Tensor& tensor : *fetch_tensors) {
+      tensor.AsProtoTensorContent(resp->add_fetch());
+    }
+    return Flush(op);
+  }
+
+  Status RecordReleaseCallable(Session* session,
+                               Session::CallableHandle handle) {
+    ReplayOp op;
+    ReleaseCallableRequest* req = op.mutable_release_callable();
+    req->set_session_handle(SessionToHandle(session));
+    req->set_handle(handle);
+    RUN_WITH_TIMESTAMP(ReleaseCallable, handle);
+    return Flush(op);
+  }
+
+ private:
+  Status Flush(const ReplayOp& op) {
+    string buf;
+    op.SerializeToString(&buf);
+    TF_RETURN_IF_ERROR(log_writer_->WriteRecord(buf));
+
+    // Flushing the RecordWriter _does not_ flush the underlying file.
+    TF_RETURN_IF_ERROR(log_writer_->Flush());
+    return log_file_->Flush();
+  }
+
+  mutex log_mutex_;
+  std::unique_ptr<io::RecordWriter> log_writer_;
+  std::unique_ptr<WritableFile> log_file_;
+};
+
+static SessionLogger* global_session_logger() {
+  static SessionLogger* logger = new SessionLogger();
+  return logger;
+}
+
+SessionRef::SessionRef(Session* session) : session_(session) {
+  if (getenv("TF_REPLAY_LOG_FILE") != nullptr) {
+    logger_ = global_session_logger();
+    logger_->RecordCreateSession(this->session_.get()).IgnoreError();
+  } else {
+    logger_ = nullptr;
+  }
+}
+
+SessionRef::~SessionRef() = default;
+
+Status SessionRef::CheckNotClosed() {
+  mutex_lock l(run_lock_);
+  if (session_ == nullptr) return errors::Cancelled("Session has been closed.");
+  return ::tensorflow::Status::OK();
+}
+
+// If logging is active, log the start and end time of the operation along with
+// the request and response.
+#define LOG_AND_RUN_OPERATION(OpName, ...)                          \
+  TF_RETURN_IF_ERROR(CheckNotClosed());                             \
+  RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); \
+  if (!logger_) {                                                   \
+    return rc.session->OpName(__VA_ARGS__);                         \
+  }                                                                 \
+  return logger_->Record##OpName(rc.session.get(), __VA_ARGS__);
+
+Status SessionRef::Run(const RunOptions& run_options,
+                       const std::vector<std::pair<string, Tensor> >& inputs,
+                       const std::vector<string>& output_tensor_names,
+                       const std::vector<string>& target_node_names,
+                       std::vector<Tensor>* outputs,
+                       RunMetadata* run_metadata) {
+  LOG_AND_RUN_OPERATION(Run, run_options, inputs, output_tensor_names,
+                        target_node_names, outputs, run_metadata);
+}
+
+Status SessionRef::Run(const std::vector<std::pair<string, Tensor> >& inputs,
+                       const std::vector<string>& output_tensor_names,
+                       const std::vector<string>& target_node_names,
+                       std::vector<Tensor>* outputs) {
+  LOG_AND_RUN_OPERATION(Run, inputs, output_tensor_names, target_node_names,
+                        outputs);
+}
+
+Status SessionRef::Create(const GraphDef& graph) {
+  LOG_AND_RUN_OPERATION(Create, graph);
+}
+
+Status SessionRef::Create(const RunOptions& run_options,
+                          const GraphDef& graph) {
+  LOG_AND_RUN_OPERATION(Create, run_options, graph);
+}
+
+Status SessionRef::Extend(const RunOptions& run_options,
+                          const GraphDef& graph) {
+  LOG_AND_RUN_OPERATION(Extend, run_options, graph);
+}
+
+Status SessionRef::Extend(const GraphDef& graph) {
+  LOG_AND_RUN_OPERATION(Extend, graph);
+}
+
+Status SessionRef::ListDevices(std::vector<DeviceAttributes>* response) {
+  LOG_AND_RUN_OPERATION(ListDevices, response);
+}
+
+Status SessionRef::PRunSetup(const std::vector<string>& input_names,
+                             const std::vector<string>& output_names,
+                             const std::vector<string>& target_nodes,
+                             string* handle) {
+  LOG_AND_RUN_OPERATION(PRunSetup, input_names, output_names, target_nodes,
+                        handle);
+}
+
+Status SessionRef::PRun(const string& handle,
+                        const std::vector<std::pair<string, Tensor> >& inputs,
+                        const std::vector<string>& output_names,
+                        std::vector<Tensor>* outputs) {
+  LOG_AND_RUN_OPERATION(PRun, handle, inputs, output_names, outputs);
+}
+
+Status SessionRef::MakeCallable(const CallableOptions& callable_options,
+                                CallableHandle* out_handle) {
+  LOG_AND_RUN_OPERATION(MakeCallable, callable_options, out_handle);
+}
+
+Status SessionRef::RunCallable(CallableHandle handle,
+                               const std::vector<Tensor>& feed_tensors,
+                               std::vector<Tensor>* fetch_tensors,
+                               RunMetadata* run_metadata) {
+  LOG_AND_RUN_OPERATION(RunCallable, handle, feed_tensors, fetch_tensors,
+                        run_metadata);
+}
+
+Status SessionRef::ReleaseCallable(CallableHandle handle) {
+  LOG_AND_RUN_OPERATION(ReleaseCallable, handle);
+}
+
+Status SessionRef::Close(const RunOptions& run_options) {
+  TF_RETURN_IF_ERROR(CheckNotClosed());
+  mutex_lock l(run_lock_);
+  Status status;
+  if (logger_) {
+    status = logger_->RecordClose(session_.get(), run_options);
+  } else {
+    status = session_->Close(run_options);
+  }
+  session_.reset();
+  while (run_count_ > 0) {
+    run_finished_.wait(l);
+  }
+  return status;
+}
+
+Status SessionRef::Close() {
+  TF_RETURN_IF_ERROR(CheckNotClosed());
+  mutex_lock l(run_lock_);
+  Status status;
+  if (logger_) {
+    status = logger_->RecordClose(session_.get());
+  } else {
+    status = session_->Close();
+  }
+  session_.reset();
+  while (run_count_ > 0) {
+    run_finished_.wait(l);
+  }
+  return status;
+}
+
+}  // namespace tensorflow
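
A rough sketch of how the replay log above is expected to be enabled from
Python. This assumes TF_REPLAY_LOG_FILE is read when the session (and hence
SessionRef) is constructed, as in the constructor above; the log path is
illustrative only.

    import os
    os.environ["TF_REPLAY_LOG_FILE"] = "/tmp/tf_replay.log"  # illustrative path

    import tensorflow as tf

    # The variable must be set before the session is created.
    with tf.Session() as sess:
      x = tf.constant(21.0)
      # Each Run() is recorded as a RunStepRequest/RunStepResponse pair.
      sess.run(x * 2.0)
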
diff --git a/tensorflow/core/common_runtime/session_ref.h b/tensorflow/python/client/session_ref.h
similarity index 90%
rename from tensorflow/core/common_runtime/session_ref.h
rename to tensorflow/python/client/session_ref.h
index 9459e7e..b0fb12b 100644
--- a/tensorflow/core/common_runtime/session_ref.h
+++ b/tensorflow/python/client/session_ref.h
@@ -12,8 +12,8 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_REF_H_
-#define TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_REF_H_
+#ifndef TENSORFLOW_PYTHON_CLIENT_SESSION_REF_H_
+#define TENSORFLOW_PYTHON_CLIENT_SESSION_REF_H_
 
 #include <memory>
 
@@ -22,6 +22,8 @@
 
 namespace tensorflow {
 
+class SessionLogger;
+
 // A `SessionRef` manages the lifetime of a wrapped `Session` pointer.
 //
 // SessionRef blocks the return of Close() until all pending operations have
@@ -29,8 +31,8 @@
 // subsequent operations on the SessionRef object will return errors::Cancelled.
 class SessionRef : public Session {
  public:
-  SessionRef(Session* session) : session_(session) {}
-  virtual ~SessionRef() {}
+  explicit SessionRef(Session* session);
+  ~SessionRef() override;
 
   Status Create(const GraphDef& graph) override;
   Status Extend(const GraphDef& graph) override;
@@ -78,9 +80,12 @@
   uint64 run_count_ GUARDED_BY(run_lock_) = {0};
   std::shared_ptr<Session> session_;
 
+  // Borrowed reference to global session logger.
+  SessionLogger* logger_;
+
   Status CheckNotClosed();
 };
 
 }  // namespace tensorflow
 
-#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_REF_H_
+#endif  // TENSORFLOW_PYTHON_CLIENT_SESSION_REF_H_
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index 4afc639..f576435 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -61,6 +61,12 @@
 from tensorflow.python.training import server_lib
 from tensorflow.python.util import compat
 
+try:
+  import attr  # pylint:disable=g-import-not-at-top
+except ImportError:
+  attr = None
+
+
 # NOTE(mrry): Dummy shape registration for ops used in the tests, since they
 # don't have C++ op registrations on which to attach C++ shape fns.
 ops.RegisterShape('ConstructionFails')(common_shapes.unknown_shape)
@@ -300,6 +306,82 @@
       self.assertEqual(None, res[2])
       self.assertEqual(44.0, res[1])
 
+  def testFetchAttrs(self):
+    if attr is None:
+      self.skipTest('attr module is unavailable.')
+
+    @attr.s
+    class SampleAttr(object):
+      field1 = attr.ib()
+      field2 = attr.ib()
+
+    val1 = np.array([1.2, 3.4, 5.6])
+    val2 = np.array([[1, 2], [4, 3]])
+    val3 = np.array([10, 20, 30])
+
+    t1 = constant_op.constant(val1)
+    t2 = constant_op.constant(val2)
+
+    sample = SampleAttr(t1, t2)
+    with session.Session() as sess:
+      result = sess.run(sample)
+      self.assertIsInstance(result, SampleAttr)
+      self.assertAllEqual(val1, result.field1)
+      self.assertAllEqual(val2, result.field2)
+
+      result = sess.run(sample, feed_dict={sample.field1: val3})
+      self.assertIsInstance(result, SampleAttr)
+      self.assertAllEqual(val3, result.field1)
+      self.assertAllEqual(val2, result.field2)
+
+  def testFetchNestedAttrs(self):
+    if attr is None:
+      self.skipTest('attr module is unavailable.')
+
+    @attr.s
+    class SampleAttr(object):
+      field0 = attr.ib()
+      field1 = attr.ib()
+
+    v1 = 10
+    v2 = 20
+    v3 = np.float32(1.2)
+    v4 = np.float32(3.4)
+    v5 = np.float64(100.001)
+    v6 = np.float64(-23.451)
+    arr1 = np.array([1.2, 6.7, 3.4])
+    arr2 = np.array([7, 11, 3])
+    sample = SampleAttr(
+        SampleAttr(
+            SampleAttr(constant_op.constant(v1), constant_op.constant(v2)),
+            SampleAttr(constant_op.constant(arr1), constant_op.constant(arr2))),
+        {'A': SampleAttr(constant_op.constant(v3), constant_op.constant(v4)),
+         'B': [SampleAttr(constant_op.constant(v5), constant_op.constant(v6))]})
+
+    with session.Session() as sess:
+      result = sess.run(sample)
+      self.assertIsInstance(result, SampleAttr)
+      self.assertIsInstance(result.field0, SampleAttr)
+      self.assertIsInstance(result.field0.field0, SampleAttr)
+      self.assertIsInstance(result.field0.field1, SampleAttr)
+      self.assertIsInstance(result.field0.field1.field0, np.ndarray)
+      self.assertAllEqual(arr1, result.field0.field1.field0)
+      self.assertIsInstance(result.field0.field1.field1, np.ndarray)
+      self.assertAllEqual(arr2, result.field0.field1.field1)
+      self.assertIsInstance(result.field1, dict)
+      self.assertIn('A', result.field1)
+      self.assertIn('B', result.field1)
+      self.assertIsInstance(result.field1['A'], SampleAttr)
+      self.assertAllEqual(
+          [v3, v4],
+          [result.field1['A'].field0, result.field1['A'].field1])
+      self.assertIsInstance(result.field1['B'], list)
+      self.assertEqual(1, len(result.field1['B']))
+      self.assertIsInstance(result.field1['B'][0], SampleAttr)
+      self.assertAllEqual(
+          [v5, v6],
+          [result.field1['B'][0].field0, result.field1['B'][0].field1])
+
   def testFetchNestingEmptyOneLevel(self):
     with session.Session() as sess:
       a_val = 11.0
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index 39a2922..ef7527d 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -463,7 +463,7 @@
 }
 
 // Override default py3 behavior of attempting to encode into Unicode.
-%typemap(out) std::string tensorflow::GetResourceHandleShapeAndType {
+%typemap(out) std::string tensorflow::GetHandleShapeAndType {
   $result = PyBytes_FromStringAndSize($1.data(), $1.size());
 }
 
@@ -782,7 +782,7 @@
 %unignore TF_TryEvaluateConstant_wrapper;
 %noexception TF_TryEvaluateConstant_wrapper;
 %unignore ExtendSession;
-%unignore ResourceHandleShapeAndType;
+%unignore HandleShapeAndType;
 
 %include "tensorflow/python/client/tf_session_helper.h"
 
diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc
index bcd4af2..dc0c10b 100644
--- a/tensorflow/python/client/tf_session_helper.cc
+++ b/tensorflow/python/client/tf_session_helper.cc
@@ -20,7 +20,6 @@
 #include "tensorflow/c/c_api.h"
 #include "tensorflow/c/c_api_internal.h"
 #include "tensorflow/c/tf_status_helper.h"
-#include "tensorflow/core/common_runtime/session_ref.h"
 #include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/attr_value_util.h"
@@ -31,6 +30,7 @@
 #include "tensorflow/core/lib/strings/stringprintf.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/util/equal_graph_def.h"
+#include "tensorflow/python/client/session_ref.h"
 #include "tensorflow/python/lib/core/ndarray_tensor.h"
 #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
 #include "tensorflow/python/lib/core/safe_ptr.h"
diff --git a/tensorflow/python/client/timeline.py b/tensorflow/python/client/timeline.py
index 1e96ac5..c3f3829 100644
--- a/tensorflow/python/client/timeline.py
+++ b/tensorflow/python/client/timeline.py
@@ -588,7 +588,8 @@
       alloc_tensor_set = set()
       alloc_maxes[allocator] = AllocationMaximum(
           timestamp=0, num_bytes=0, tensors=set())
-      for time, num_bytes, name in alloc_list:
+      for time, num_bytes, name in sorted(
+          alloc_list, key=lambda allocation: allocation[0]):
         total_bytes += num_bytes
         if num_bytes < 0:
           alloc_tensor_set.discard(name)
diff --git a/tensorflow/python/client/timeline_test.py b/tensorflow/python/client/timeline_test.py
index 03effde..032bbf7 100644
--- a/tensorflow/python/client/timeline_test.py
+++ b/tensorflow/python/client/timeline_test.py
@@ -163,8 +163,6 @@
     # At least num1 + num2, both float32s (4 bytes each)
     self.assertGreaterEqual(cpu_max.num_bytes, 8)
     self.assertGreater(cpu_max.timestamp, 0)
-    self.assertTrue('num1' in cpu_max.tensors or 'num1/read' in cpu_max.tensors)
-    self.assertTrue('num2' in cpu_max.tensors or 'num2/read' in cpu_max.tensors)
 
   def testManyCPUs(self):
     run_options = config_pb2.RunOptions(
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 1a1ed04..5e8f5d6 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -26,7 +26,7 @@
 from tensorflow.python.util import tf_contextlib
 from tensorflow.python.util.tf_export import tf_export
 
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 13)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 21)
 
 
 @tf_export("compat.forward_compatible")
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 631b87a..17d4fec 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -407,3 +407,20 @@
         "//tensorflow/python:tensor_shape",
     ],
 )
+
+tf_py_test(
+    name = "window_dataset_op_test",
+    size = "small",
+    srcs = ["window_dataset_op_test.py"],
+    additional_deps = [
+        "@absl_py//absl/testing:parameterized",
+        "//third_party/py/numpy",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:errors",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:sparse_tensor",
+        "//tensorflow/python/data/ops:dataset_ops",
+    ],
+)
diff --git a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
index a35cee5..e7e51df 100644
--- a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
@@ -134,7 +134,7 @@
         result.append([value] * value)
       return result * count
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for expected_element in self._interleave(
           repeat(input_values, count), cycle_length, block_length):
         self.assertEqual(expected_element, sess.run(get_next))
@@ -169,7 +169,7 @@
             num_parallel_calls)
     get_next = dataset.make_one_shot_iterator().get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for value in input_values:
         if np.isnan(value):
           with self.assertRaises(errors.InvalidArgumentError):
@@ -195,7 +195,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for i in range(10):
         for j in range(2):
diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
index 7685d8d..ae04995 100644
--- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
@@ -397,6 +397,28 @@
       # Randomness is repeatable given same seed
       self.assertAllClose(random_values, random_values_2)
 
+  def testStatefulMapKeepsStateAcrossIterators(self):
+    iterator = (dataset_ops.Dataset.from_tensors(0).repeat(10)
+                .map(lambda _: random_ops.random_uniform((), seed=11))
+                .repeat(1000)
+                .batch(10)
+                .make_initializable_iterator())
+    init_op = iterator.initializer
+    get_next = iterator.get_next()
+
+    with self.cached_session() as sess:
+      sess.run(init_op)
+      random_values = sess.run(get_next)
+
+      # Assert that one of the next 99 batches yielded by the iterator is
+      # different from the first.
+      i = 0
+      while i < 99:
+        if np.any(random_values != sess.run(get_next)):
+          break
+        i += 1
+      self.assertLess(i, 99)
+
   def testMapDict(self):
     iterator = (dataset_ops.Dataset.range(10)
                 .map(lambda x: {"foo": x * 2, "bar": x ** 2})
@@ -731,7 +753,7 @@
     iterator = dataset.make_one_shot_iterator()
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       tids = sess.run(get_next)
       self.assertTrue(all(tids[0] == tid for tid in tids))
 # pylint: enable=g-long-lambda
diff --git a/tensorflow/python/data/kernel_tests/window_dataset_op_test.py b/tensorflow/python/data/kernel_tests/window_dataset_op_test.py
new file mode 100644
index 0000000..fd43484
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/window_dataset_op_test.py
@@ -0,0 +1,295 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class WindowDatasetTest(test.TestCase, parameterized.TestCase):
+
+  @parameterized.named_parameters(
+      ("1", 20, 14, 7, 1),
+      ("2", 20, 17, 9, 1),
+      ("3", 20, 14, 14, 1),
+      ("4", 20, 10, 14, 1),
+      ("5", 20, 14, 19, 1),
+      ("6", 20, 4, 1, 2),
+      ("7", 20, 2, 1, 6),
+      ("8", 20, 4, 7, 2),
+      ("9", 20, 2, 7, 6),
+      ("10", 1, 10, 4, 1),
+      ("11", 0, 10, 4, 1),
+      ("12", 20, 14, 7, 1, False),
+      ("13", 20, 17, 9, 1, False),
+      ("14", 20, 14, 14, 1, False),
+      ("15", 20, 10, 14, 1, False),
+      ("16", 20, 14, 19, 1, False),
+      ("17", 20, 4, 1, 2, False),
+      ("18", 20, 2, 1, 6, False),
+      ("19", 20, 4, 7, 2, False),
+      ("20", 20, 2, 7, 6, False),
+      ("21", 1, 10, 4, 1, False),
+      ("22", 0, 10, 4, 1, False),
+  )
+  def testWindowDataset(self, count, size, shift, stride, drop_remainder=True):
+    """Tests a dataset that slides a window its input elements."""
+    components = (np.arange(7),
+                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
+                  np.array(37.0) * np.arange(7))
+
+    count_t = array_ops.placeholder(dtypes.int64, shape=[])
+    size_t = array_ops.placeholder(dtypes.int64, shape=[])
+    shift_t = array_ops.placeholder(dtypes.int64, shape=[])
+    stride_t = array_ops.placeholder(dtypes.int64, shape=[])
+    drop_remainder_t = array_ops.placeholder(dtypes.bool, shape=[])
+
+    def _map_fn(x, y, z):
+      return math_ops.square(x), math_ops.square(y), math_ops.square(z)
+
+    def _flat_map_fn(x, y, z):
+      return dataset_ops.Dataset.zip((x.batch(batch_size=size_t),
+                                      y.batch(batch_size=size_t),
+                                      z.batch(batch_size=size_t)))
+
+    iterator = dataset_ops.Dataset.from_tensor_slices(components).map(
+        _map_fn).repeat(count).window(
+            size=size_t,
+            shift=shift_t,
+            stride=stride_t,
+            drop_remainder=drop_remainder_t).flat_map(
+                _flat_map_fn).make_initializable_iterator()
+    init_op = iterator.initializer
+    get_next = iterator.get_next()
+
+    self.assertEqual([[None] + list(c.shape[1:]) for c in components],
+                     [t.shape.as_list() for t in get_next])
+
+    with self.cached_session() as sess:
+      sess.run(
+          init_op,
+          feed_dict={
+              count_t: count,
+              size_t: size,
+              shift_t: shift,
+              stride_t: stride,
+              drop_remainder_t: drop_remainder
+          })
+      num_full_batches = max(
+          0, (count * 7 - ((size - 1) * stride + 1)) // shift + 1)
+      for i in range(num_full_batches):
+        result = sess.run(get_next)
+        for component, result_component in zip(components, result):
+          for j in range(size):
+            self.assertAllEqual(component[(i * shift + j * stride) % 7]**2,
+                                result_component[j])
+      if not drop_remainder:
+        num_partial_batches = (count * 7) // shift + (
+            (count * 7) % shift > 0) - num_full_batches
+        for i in range(num_partial_batches):
+          result = sess.run(get_next)
+          for component, result_component in zip(components, result):
+            remaining = (count * 7) - ((num_full_batches + i) * shift)
+            num_elements = remaining // stride + ((remaining % stride) > 0)
+            for j in range(num_elements):
+              self.assertAllEqual(
+                  component[((num_full_batches + i) * shift + j * stride) % 7]
+                  **2, result_component[j])
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
+  @parameterized.named_parameters(
+      ("1", 14, 0, 3, 1),
+      ("2", 14, 3, 0, 1),
+      ("3", 14, 3, 3, 0),
+  )
+  def testWindowDatasetInvalid(self, count, size, shift, stride):
+    count_t = array_ops.placeholder(dtypes.int64, shape=[])
+    size_t = array_ops.placeholder(dtypes.int64, shape=[])
+    shift_t = array_ops.placeholder(dtypes.int64, shape=[])
+    stride_t = array_ops.placeholder(dtypes.int64, shape=[])
+
+    iterator = dataset_ops.Dataset.range(10).map(lambda x: x).repeat(
+        count_t).window(
+            size=size_t, shift=shift_t,
+            stride=stride_t).flat_map(lambda x: x.batch(batch_size=size_t)
+                                     ).make_initializable_iterator()
+    init_op = iterator.initializer
+
+    with self.cached_session() as sess:
+      with self.assertRaises(errors.InvalidArgumentError):
+        sess.run(
+            init_op,
+            feed_dict={
+                count_t: count,
+                size_t: size,
+                shift_t: shift,
+                stride_t: stride
+            })
+
+  def assertSparseValuesEqual(self, a, b):
+    self.assertAllEqual(a.indices, b.indices)
+    self.assertAllEqual(a.values, b.values)
+    self.assertAllEqual(a.dense_shape, b.dense_shape)
+
+  def testWindowSparse(self):
+
+    def _sparse(i):
+      return sparse_tensor.SparseTensorValue(
+          indices=[[0]], values=(i * [1]), dense_shape=[1])
+
+    iterator = dataset_ops.Dataset.range(10).map(_sparse).window(
+        size=5, shift=3, drop_remainder=True).flat_map(
+            lambda x: x.batch(batch_size=5)).make_initializable_iterator()
+    init_op = iterator.initializer
+    get_next = iterator.get_next()
+
+    with self.cached_session() as sess:
+      sess.run(init_op)
+      num_batches = (10 - 5) // 3 + 1
+      for i in range(num_batches):
+        actual = sess.run(get_next)
+        expected = sparse_tensor.SparseTensorValue(
+            indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
+            values=[i * 3, i * 3 + 1, i * 3 + 2, i * 3 + 3, i * 3 + 4],
+            dense_shape=[5, 1])
+        self.assertTrue(sparse_tensor.is_sparse(actual))
+        self.assertSparseValuesEqual(actual, expected)
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
+  def testWindowSparseWithDifferentDenseShapes(self):
+
+    def _sparse(i):
+      return sparse_tensor.SparseTensorValue(
+          indices=array_ops.expand_dims(
+              math_ops.range(i, dtype=dtypes.int64), 1),
+          values=array_ops.fill([math_ops.to_int32(i)], i),
+          dense_shape=[i])
+
+    iterator = dataset_ops.Dataset.range(10).map(_sparse).window(
+        size=5, shift=3, drop_remainder=True).flat_map(
+            lambda x: x.batch(batch_size=5)).make_initializable_iterator()
+    init_op = iterator.initializer
+    get_next = iterator.get_next()
+
+    with self.cached_session() as sess:
+      sess.run(init_op)
+      num_batches = (10 - 5) // 3 + 1
+      for i in range(num_batches):
+        actual = sess.run(get_next)
+        expected_indices = []
+        expected_values = []
+        for j in range(5):
+          for k in range(i * 3 + j):
+            expected_indices.append([j, k])
+            expected_values.append(i * 3 + j)
+        expected = sparse_tensor.SparseTensorValue(
+            indices=expected_indices,
+            values=expected_values,
+            dense_shape=[5, i * 3 + 5 - 1])
+        self.assertTrue(sparse_tensor.is_sparse(actual))
+        self.assertSparseValuesEqual(actual, expected)
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
+  def testNestedWindowSparse(self):
+
+    def _sparse(i):
+      return sparse_tensor.SparseTensorValue(
+          indices=[[0]], values=(i * [1]), dense_shape=[1])
+
+    iterator = dataset_ops.Dataset.range(10).map(_sparse).window(
+        size=4, shift=2,
+        drop_remainder=True).flat_map(lambda x: x.batch(batch_size=4)).window(
+            size=3, shift=1, drop_remainder=True).flat_map(
+                lambda x: x.batch(batch_size=3)).make_initializable_iterator()
+    init_op = iterator.initializer
+    get_next = iterator.get_next()
+
+    with self.cached_session() as sess:
+      sess.run(init_op)
+      # Slide: 1st batch.
+      actual = sess.run(get_next)
+      expected = sparse_tensor.SparseTensorValue(
+          indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0],
+                   [1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0],
+                   [2, 2, 0], [2, 3, 0]],
+          values=[0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7],
+          dense_shape=[3, 4, 1])
+      self.assertTrue(sparse_tensor.is_sparse(actual))
+      self.assertSparseValuesEqual(actual, expected)
+      # Slide: 2nd batch.
+      actual = sess.run(get_next)
+      expected = sparse_tensor.SparseTensorValue(
+          indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0],
+                   [1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0],
+                   [2, 2, 0], [2, 3, 0]],
+          values=[2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9],
+          dense_shape=[3, 4, 1])
+      self.assertTrue(sparse_tensor.is_sparse(actual))
+      self.assertSparseValuesEqual(actual, expected)
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
+  def testWindowShapeError(self):
+
+    def generator():
+      yield [1.0, 2.0, 3.0]
+      yield [4.0, 5.0, 6.0]
+      yield [7.0, 8.0, 9.0, 10.0]
+
+    iterator = dataset_ops.Dataset.from_generator(
+        generator, dtypes.float32, output_shapes=[None]).window(
+            size=3, shift=1).flat_map(
+                lambda x: x.batch(batch_size=3)).make_initializable_iterator()
+    next_element = iterator.get_next()
+
+    with self.cached_session() as sess:
+      sess.run(iterator.initializer)
+      with self.assertRaisesRegexp(
+          errors.InvalidArgumentError,
+          r"Cannot batch tensors with different shapes in component 0. "
+          r"First element had shape \[3\] and element 2 had shape \[4\]."):
+        sess.run(next_element)
+
+  def testWindowIgnoreErrors(self):
+    input_values = np.float32([1., np.nan, 2., np.nan, 3.])
+    dataset = dataset_ops.Dataset.from_tensor_slices(input_values).map(
+        lambda x: array_ops.check_numerics(x, "message")).window(
+            size=2, shift=2, stride=2,
+            drop_remainder=True).flat_map(lambda x: x.batch(batch_size=2))
+    get_next = dataset.make_one_shot_iterator().get_next()
+
+    with self.cached_session() as sess:
+      self.assertAllEqual(np.float32([1., 2.]), sess.run(get_next))
+      self.assertAllEqual(np.float32([2., 3.]), sess.run(get_next))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index c985e00..93b3a7b 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -1115,7 +1115,7 @@
     return FilterDataset(self, predicate)
 
   def apply(self, transformation_func):
-    """Apply a transformation function to this dataset.
+    """Applies a transformation function to this dataset.
 
     `apply` enables chaining of custom `Dataset` transformations, which are
     represented as functions that take one `Dataset` argument and return a
@@ -1131,7 +1131,7 @@
 
     Args:
       transformation_func: A function that takes one `Dataset` argument and
-          returns a `Dataset`.
+        returns a `Dataset`.
 
     Returns:
       Dataset: The `Dataset` returned by applying `transformation_func` to this
@@ -1142,6 +1142,45 @@
       raise TypeError("`transformation_func` must return a Dataset.")
     return dataset
 
+  def window(self, size, shift=None, stride=1, drop_remainder=False):
+    """Combines input elements into a dataset of windows.
+
+    Each window is a dataset itself and contains `size` elements (or
+    possibly fewer if there are not enough input elements to fill the window
+    and `drop_remainder` evaluates to false).
+
+    The `stride` argument determines the stride between input elements within
+    a window, and the `shift` argument determines how many input elements the
+    window moves forward on each iteration.
+
+    For example:
+    - `tf.data.Dataset.range(7).window(2)` produces
+      `{{0, 1}, {2, 3}, {4, 5}, {6}}`
+    - `tf.data.Dataset.range(7).window(3, 2, 1, True)` produces
+      `{{0, 1, 2}, {2, 3, 4}, {4, 5, 6}}`
+    - `tf.data.Dataset.range(7).window(3, 1, 2, True)` produces
+      `{{0, 2, 4}, {1, 3, 5}, {2, 4, 6}}`
+
+    Args:
+      size: A `tf.int64` scalar `tf.Tensor`, representing the number of elements
+        of the input dataset to combine into a window.
+      shift: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
+        forward shift of the sliding window in each iteration. Defaults to
+        `size`.
+      stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
+        stride of the input elements in the sliding window.
+      drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
+        whether a window should be dropped if its size is smaller than
+        `size`.
+
+    Returns:
+      Dataset: A `Dataset` of windows, each of which is a nested `Dataset` with
+        the same structure as this dataset, but a finite subsequence of its
+        elements.
+    """
+    if shift is None:
+      shift = size
+    return WindowDataset(self, size, shift, stride, drop_remainder)
+
 
 class TensorDataset(Dataset):
   """A `Dataset` with a single element, viz. a nested structure of tensors."""
@@ -2442,3 +2481,53 @@
   @property
   def output_types(self):
     return self._input_dataset.output_types
+
+
+class WindowDataset(Dataset):
+  """A dataset that creates window datasets from the input elements."""
+
+  def __init__(self, input_dataset, size, shift, stride, drop_remainder):
+    """See `window_dataset()` for more details."""
+    super(WindowDataset, self).__init__()
+    self._input_dataset = input_dataset
+    self._size = ops.convert_to_tensor(size, dtype=dtypes.int64, name="size")
+    self._shift = ops.convert_to_tensor(shift, dtype=dtypes.int64, name="shift")
+    self._stride = ops.convert_to_tensor(
+        stride, dtype=dtypes.int64, name="stride")
+    self._drop_remainder = ops.convert_to_tensor(
+        drop_remainder, dtype=dtypes.bool, name="drop_remainder")
+    self._output_classes = nest.pack_sequence_as(
+        input_dataset.output_classes,
+        [
+            _NestedDatasetComponent(  # pylint: disable=protected-access
+                output_classes=output_class,
+                output_shapes=output_shape,
+                output_types=output_type)
+            for output_class, output_shape, output_type in zip(
+                nest.flatten(input_dataset.output_classes),
+                nest.flatten(input_dataset.output_shapes),
+                nest.flatten(input_dataset.output_types))
+        ])
+    self._output_shapes = self._output_classes
+    self._output_types = self._output_classes
+
+  def _as_variant_tensor(self):
+    return gen_dataset_ops.window_dataset(
+        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
+        self._size,
+        self._shift,
+        self._stride,
+        self._drop_remainder,
+        **flat_structure(self))
+
+  @property
+  def output_classes(self):
+    return self._output_classes
+
+  @property
+  def output_shapes(self):
+    return self._output_shapes
+
+  @property
+  def output_types(self):
+    return self._output_types
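
A small usage sketch of the `window()` transformation added above. Each window
is itself a nested `Dataset`, so, as in the tests above, `flat_map` plus
`batch` is used to materialize the contents of each window (TF 1.x graph-mode
style).

    import tensorflow as tf

    # Windows of 3 elements, moving forward 2 elements at a time; partial
    # trailing windows are dropped.
    dataset = tf.data.Dataset.range(7).window(size=3, shift=2, drop_remainder=True)
    # Each element is a nested Dataset; batch it to turn a window into one tensor.
    dataset = dataset.flat_map(lambda window: window.batch(3))

    next_window = dataset.make_one_shot_iterator().get_next()
    with tf.Session() as sess:
      print(sess.run(next_window))  # [0 1 2]
      print(sess.run(next_window))  # [2 3 4]
      print(sess.run(next_window))  # [4 5 6]
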
diff --git a/tensorflow/python/debug/lib/session_debug_grpc_test.py b/tensorflow/python/debug/lib/session_debug_grpc_test.py
index ff49b69..91f21cb 100644
--- a/tensorflow/python/debug/lib/session_debug_grpc_test.py
+++ b/tensorflow/python/debug/lib/session_debug_grpc_test.py
@@ -741,7 +741,7 @@
      debug_server) = grpc_debug_test_server.start_server_on_separate_thread(
          server_start_delay_sec=2.0, dump_to_filesystem=False)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       a_init = constant_op.constant(42.0, name="a_init")
       a = variables.Variable(a_init, name="a")
 
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index c1bc27d..a2686c6 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -34,6 +34,7 @@
         "//tensorflow/python:safe_ptr",
         "//third_party/py/numpy:headers",
         "//third_party/python_runtime:headers",
+        "@com_google_absl//absl/types:variant",
     ],
 )
 
@@ -146,6 +147,7 @@
         "//tensorflow/python:clip_ops",
         "//tensorflow/python:init_ops",
         "//tensorflow/python:layers",
+        "//tensorflow/python:list_ops",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:resource_variable_ops",
     ],
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index be392c7..d95e0fe 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -120,27 +120,6 @@
 pywrap_tensorflow.TFE_Py_RegisterGradientFunction(_gradient_function)
 
 
-_tracing = False
-
-
-# TODO(agarwal): use an automatic mechanism for handling None arguments to
-# gradient functions.
-# Some gradient functions can accept None arguments for gradients. The following
-# maps the operation name to the indices at which the corresponding gradient
-# function can accept None values.
-# e.g. FusedBatchNorm outputs 5 values and hence receives 5 gradient values
-# during backprop. However the gradient function uses only the first of those
-# values and ignores the rest. The entry, "FusedBatchNorm": [1, 2, 3, 4],
-# indicates that only the gradient corresponding to index 0 is used, and the
-# gradient values at indices 1-4 are ignored (and hence can be None). The
-# backprop algorithm can then leverage this by not constructing zeros to
-# pass for those indices.
-_grad_fn_accepts_none_for_indices = {
-    "SoftmaxCrossEntropyWithLogits": [1],
-    "FusedBatchNorm": [1, 2, 3, 4]
-}
-
-
 def _record_gradient(op_name, inputs, attrs, results, name):
   return pywrap_tensorflow.TFE_Py_RecordGradient(op_name, inputs, attrs,
                                                  results, name)
@@ -629,8 +608,9 @@
 _default_vspace = imperative_grad.VSpace(
     num_elements_fn=_num_elements,
     aggregate_fn=_aggregate_grads,
-    zeros=_zeros,
-    ones=_ones)
+    zeros_fn=_zeros,
+    ones_fn=_ones,
+    graph_shape_fn=gen_array_ops.shape)
 pywrap_tensorflow.TFE_Py_RegisterVSpace(_default_vspace)
 
 
@@ -648,8 +628,8 @@
   Operations are recorded if they are executed within this context manager and
   at least one of their inputs is being "watched".
 
-  Trainable variables (created by `tf.Variable` or `tf.get_variable`,
-  trainable=True is default in both cases) are automatically watched. Tensors
+  Trainable variables (created by `tf.Variable` or `tf.get_variable`, where
+  `trainable=True` is default in both cases) are automatically watched. Tensors
   can be manually watched by invoking the `watch` method on this context
   manager.
 
@@ -669,6 +649,7 @@
   ```python
   x = tf.constant(3.0)
   with tf.GradientTape() as g:
+    g.watch(x)
     with tf.GradientTape() as gg:
       gg.watch(x)
       y = x * x
@@ -745,7 +726,9 @@
     self._persistent = persistent
     self._watch_accessed_variables = watch_accessed_variables
     self._recording = False
-    context.context().start_step()
+    self._created_eagerly = context.executing_eagerly()
+    if self._created_eagerly:
+      context.context().start_step()
 
   def __enter__(self):
     """Enters a context inside which operations are recorded on this tape."""
@@ -775,7 +758,8 @@
     self._recording = False
 
   def __del__(self):
-    context.context().end_step()
+    if self._created_eagerly:
+      context.context().end_step()
 
   def watch(self, tensor):
     """Ensures that `tensor` is being traced by this tape.
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index f938ed5..3273174 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -1022,6 +1022,18 @@
         resource_variable_ops.ResourceVariable(2.0))
     self.assertAllEqual(gradients_constants, gradients_variables)
 
+  def testUnknownShapes(self):
+    with context.graph_mode():
+      with backprop.GradientTape() as tape:
+        a = array_ops.placeholder(dtype=dtypes.float32, shape=None)
+        tape.watch(a)
+        b = a**3
+
+      db_da = tape.gradient(b, a)
+
+      with self.cached_session() as sess:
+        self.assertEqual((8.0, 12.0), sess.run((b, db_da), feed_dict={a: 2.0}))
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 552ed29..bcb1881 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -23,6 +23,7 @@
 import functools
 import sys
 import threading
+import weakref
 
 import numpy as np
 import six
@@ -65,23 +66,43 @@
 WHITELIST_FUNCTION_ATTRIBUTE_PREFIX = "experimental_"
 
 
-def _create_substitute_placeholder(value, name, dtype=None):
+def _create_substitute_placeholder(value, name=None, dtype=None):
   """Creates a placeholder for `value` and propagates shape info to it."""
   # Note: setting ops.control_dependencies(None) ensures we always put
   # capturing placeholders outside of any control flow context.
   with ops.control_dependencies(None):
     placeholder = graph_placeholder(
         dtype=dtype or value.dtype, shape=value.shape, name=name)
-  if placeholder.dtype == dtypes_module.resource:
-    if isinstance(value, ops.EagerTensor):
-      handle_data = value._handle_data  # pylint: disable=protected-access
+  _copy_handle_data(value, placeholder)
+  return placeholder
+
+
+def _copy_handle_data(source_t, target_t):
+  """Copies HandleData for variant and resource type tensors if available.
+
+  The CppShapeInferenceResult::HandleData proto contains information about the
+  shapes and types of the element tensors of resource/variant type tensors.
+  We need to copy this across function boundaries, i.e., when capturing a
+  placeholder or when returning a function tensor as output. If we don't do this
+  the element tensors will have unknown shapes, e.g., if a TensorList variant
+  tensor is captured as a placeholder, elements popped from that list would have
+  unknown shape.
+
+  Args:
+    source_t: The tensor to copy HandleData from.
+    target_t: The tensor to copy HandleData to.
+  """
+  if (target_t.dtype == dtypes_module.resource or
+      target_t.dtype == dtypes_module.variant):
+    if isinstance(source_t, ops.EagerTensor):
+      handle_data = source_t._handle_data  # pylint: disable=protected-access
     else:
-      handle_data = resource_variable_ops.get_resource_handle_data(value)
+      handle_data = resource_variable_ops.get_resource_handle_data(source_t)
     if handle_data is not None and handle_data.is_set:
       # pylint: disable=protected-access
-      pywrap_tensorflow.SetResourceHandleShapeAndType(
-          placeholder.graph._c_graph, placeholder._as_tf_output(),
-          handle_data.SerializeToString())
+      pywrap_tensorflow.SetHandleShapeAndType(target_t.graph._c_graph,
+                                              target_t._as_tf_output(),
+                                              handle_data.SerializeToString())
       # pylint: enable=protected-access
       # Ensure that shapes and dtypes are propagated.
       shapes, types = zip(*[(pair.shape, pair.dtype)
@@ -90,12 +111,10 @@
       shapes = [[d.size for d in s.dim]
                 if not s.unknown_rank else None for s in shapes]
       pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
-          placeholder._op._graph._c_graph,  # pylint: disable=protected-access
-          placeholder._as_tf_output(),  # pylint: disable=protected-access
+          target_t._op._graph._c_graph,  # pylint: disable=protected-access
+          target_t._as_tf_output(),  # pylint: disable=protected-access
           shapes, ranks, types)
 
-  return placeholder
-
 
 def _get_device_functions(ctx, graph):
   """Returns a tuple of device functions representing the device stack."""
@@ -180,7 +199,7 @@
     self.inputs = []
     self.outputs = []
     self.structured_outputs = None
-    self.variables = []
+    self._weak_variables = []
     self.outer_graph = ops.get_default_graph()
     self.captures = collections.OrderedDict()
 
@@ -217,6 +236,31 @@
     self._graph_key = graph._graph_key
     # pylint: enable=protected-access
 
+  @property
+  def variables(self):
+    """A list of variables accessed by this FuncGraph.
+
+    Note that functions keep only weak references to variables. Calling the
+    function after a variable it accesses has been deleted is an error.
+
+    Yields:
+      Strong references to variables accessed by this FuncGraph.
+    """
+    for weak_v in self._weak_variables:
+      v = weak_v()
+      if v is None:
+        raise AssertionError(
+            "Called a function referencing variables which have been deleted. "
+            "This likely means that function-local variables were created and "
+            "not referenced elsewhere in the program. This is generally a "
+            "mistake; consider storing variables in an object attribute on "
+            "first call.")
+      yield v
+
+  @variables.setter
+  def variables(self, var_list):
+    self._weak_variables = [weakref.ref(v) for v in var_list]
+
   def create_op(
       self,
       op_type,
@@ -409,6 +453,7 @@
     self._num_outputs = len(self.signature.output_arg)
     self._output_types = [o.type for o in self.signature.output_arg]
     self._output_shapes = [o.shape for o in outputs]
+    self._func_graph_outputs = outputs
     self.grad_func_name = None
     self.python_grad_func = None
     self._c_func = c_api_util.ScopedTFFunction(fn)
@@ -485,6 +530,8 @@
     else:
       for i, shape in enumerate(self._output_shapes):
         outputs[i].set_shape(shape)
+      for i, func_graph_output in enumerate(self._func_graph_outputs):
+        _copy_handle_data(func_graph_output, outputs[i])
       return outputs
 
 
@@ -529,7 +576,7 @@
     self._num_outputs = len(self._func_graph.outputs)
     self._output_shapes = tuple(
         output.shape for output in self._func_graph.outputs)
-    self._attrs = _parse_func_attrs(attrs)
+    self._attrs = _parse_func_attrs(attrs or {})
     self._device_functions = tuple(
         self._func_graph._device_functions_outer_to_inner)  # pylint: disable=protected-access
 
@@ -550,7 +597,19 @@
           self._distributed_variables[component_variable.handle] = variable
 
   def __call__(self, *args):
-    """Executes the wrapped function."""
+    """Executes the wrapped function.
+
+    Args:
+      *args: a list of Tensors or Variables.
+
+    Returns:
+      The result of applying the TF function to `args`.
+
+    Raises:
+      ValueError: If the current device stack does not match the device stack
+        under which the function was created, or if `args` contains anything
+        other than Tensors or Variables.
+    """
     ctx = context.context()
     device_functions = _get_device_functions(ctx, ops.get_default_graph())
     if device_functions != self._device_functions:
@@ -566,7 +625,18 @@
         tape.variable_accessed(v)
 
     captures = self._resolve_captured_inputs()
-    tensor_inputs = [x for x in nest.flatten(args) if isinstance(x, ops.Tensor)]
+    tensor_inputs = []
+    for i, arg in enumerate(nest.flatten(args)):
+      if isinstance(arg, resource_variable_ops.ResourceVariable):
+        if arg.trainable:
+          tape.variable_accessed(arg)
+        tensor_inputs.append(arg.handle)
+      elif isinstance(arg, ops.Tensor):
+        tensor_inputs.append(arg)
+      else:
+        raise ValueError("All inputs to `Function`s must be Tensors; "
+                         "on invocation of %s, the %d-th input (%s) was not a "
+                         "Tensor." % (self._func_graph.name, i, str(arg)))
     args = tensor_inputs + captures
 
     if tape.should_record(tensor_inputs) or tape.should_record(captures):
@@ -581,11 +651,6 @@
     return self._func_graph
 
   @property
-  def variables(self):
-    """Returns all variables touched by this function."""
-    return self._func_graph.variables
-
-  @property
   def inputs(self):
     """Returns tensors in `self.graph` corresponding to arguments."""
     return self._func_graph.inputs
@@ -782,7 +847,12 @@
   return nest.pack_sequence_as(args, function_inputs)
 
 
-def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
+def func_graph_from_py_func(name,
+                            python_func,
+                            args,
+                            kwds,
+                            signature=None,
+                            func_graph=None):
   """Returns a `FuncGraph` generated from `python_func`.
 
   Args:
@@ -797,6 +867,8 @@
       `kwds` are ignored, and `python_func` is traced with Tensors conforming
       to `signature`. If `None`, the shapes and dtypes are inferred from the
       inputs.
+    func_graph: Optional. An instance of FuncGraph. If provided, this graph is
+      used; otherwise a new one is built and returned.
 
   Returns:
     A FuncGraph.
@@ -805,7 +877,9 @@
     TypeError: If any of `python_func`'s return values is neither `None` nor a
       `Tensor`.
   """
-  func_graph = FuncGraph(name)
+  if func_graph is None:
+    func_graph = FuncGraph(name)
+  assert isinstance(func_graph, FuncGraph)
   with func_graph.as_default(), AutomaticControlDependencies() as a:
     variable_scope.get_variable_scope().set_use_resource(True)
 
@@ -817,10 +891,6 @@
       func_kwds = {}
 
     # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
-    func_graph.inputs.extend(
-        x for x in nest.flatten(func_args) + nest.flatten(func_kwds)
-        if isinstance(x, ops.Tensor))
-
     # Variables to help check whether mutation happens in calling the function
     # Copy the recursive list, tuple and map structure, but not base objects
     func_args_before = nest.pack_sequence_as(func_args, nest.flatten(func_args))
@@ -867,6 +937,28 @@
     finally:
       tape.pop_tape(this_tape)
 
+    # Variables in `func_args`, `func_kwds` should be explicit inputs
+    # to the function, not captured inputs.
+    tape_variables = this_tape.watched_variables()
+    arg_variables = set()
+    inputs = []
+    for arg in nest.flatten(func_args) + nest.flatten(func_kwds):
+      if isinstance(arg, resource_variable_ops.ResourceVariable):
+        try:
+          resource_placeholder = func_graph.captures.pop(arg.handle)
+          arg_variables.add(arg)
+        except KeyError:
+          # This case occurs if a Variable among the inputs is not actually
+          # used by the function; we still add an explicit input for it
+          # because the user should presumably pass the Variable as an input
+          # to the corresponding graph function.
+          resource_placeholder = _create_substitute_placeholder(arg.handle)
+        inputs.append(resource_placeholder)
+      elif isinstance(arg, ops.Tensor):
+        inputs.append(arg)
+    variables = [v for v in tape_variables if v not in arg_variables]
+    func_graph.inputs = inputs + list(func_graph.captures.values())
+
     func_graph.structured_outputs = func_outputs
     # Returning a closed-over tensor does not trigger convert_to_tensor.
     func_graph.outputs.extend(
@@ -878,7 +970,6 @@
     # Instead of storing non-distributed component variables, we
     # store their distributed containers so we can retrieve the correct
     # component variables at call-time.
-    variables = list(this_tape.watched_variables())
     strategy = distribution_strategy_context.get_distribution_strategy()
     for i, variable in enumerate(variables):
       # If variable is not distributed value_container returns itself.
@@ -930,7 +1021,16 @@
     return tuple(
         (_encode_arg(key), _encode_arg(arg[key])) for key in sorted(arg))
   else:
-    return arg
+    try:
+      # If possible, keep only a weak reference to Python objects. Weak
+      # references hash to the same value as the original object.
+      # TODO(allenl): Clean up dead functions and their cache keys if the cache
+      # gets large. Right now creating objects with a defunned method, calling
+      # the method, and losing a reference to the object in a loop will leak
+      # memory here.
+      return weakref.ref(arg)
+    except TypeError:
+      return arg
 
 
 def _deterministic_dict_values(dictionary):
@@ -980,7 +1080,6 @@
       self._kwds_to_include = {}
     self._name = name
     self._function_cache = collections.OrderedDict()
-    self._variables = []
     self._function_attributes = attributes or {}
 
     self._lock = threading.Lock()
@@ -1026,12 +1125,6 @@
     """Returns the wrapped Python function."""
     return self._python_function
 
-  # TODO(akshayka): Remove this property.
-  @property
-  def variables(self):
-    """Returns the union of all variables referenced by cached `Function`s`."""
-    return self._variables
-
   def get_concrete_function(self, *args, **kwargs):
     """Returns a `Function` object specialized to inputs and execution context.
 
@@ -1198,10 +1291,11 @@
             func_graph_from_py_func(self._name, self._python_function, args,
                                     kwds, self._input_signature),
             self._function_attributes)
-        self._variables.extend(
-            [v for v in graph_function.variables if v not in self._variables])
         self._function_cache[cache_key] = graph_function
-      return graph_function, (args, kwds)
+      return graph_function, [
+          t for t in nest.flatten((args, kwds))
+          if isinstance(t, (ops.Tensor, resource_variable_ops.ResourceVariable))
+      ]
 
 
 def register(func, *args, **kwargs):
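A hedged sketch of the user-visible effect of moving variables from captured inputs to explicit inputs (it mirrors `testCallingFunctionWithDifferentVariables` in the test diff below); the imports are the same internal modules the tests use, so this is illustrative rather than a public-API recommendation.

```python
# Sketch: a variable argument is now an explicit resource input, so one
# concrete function can be invoked with different variables.
from tensorflow.python.eager import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import resource_variable_ops

ops.enable_eager_execution()

@function.defun
def increment(v):
  v.assign_add(1.0)
  return v.read_value()

a = resource_variable_ops.ResourceVariable(0.0)
b = resource_variable_ops.ResourceVariable(5.0)
concrete = increment.get_concrete_function(a)
print(float(concrete(a)))  # 1.0
print(float(concrete(b)))  # 6.0 -- `b` feeds the same input placeholder
```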
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index a0abefe..e4513cc 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -21,10 +21,12 @@
 import functools
 from multiprocessing.pool import ThreadPool
 import sys
+import weakref
 
 import numpy
 
 from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python import keras
 from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.eager import backprop
@@ -46,6 +48,7 @@
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import list_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import resource_variable_ops
@@ -74,6 +77,13 @@
     return self.fc(inputs)
 
 
+class DefunnedMiniModel(MiniModel):
+
+  @function.defun
+  def call(self, inputs, training=True):
+    return super(DefunnedMiniModel, self).call(inputs, training=training)
+
+
 @test_util.with_c_shapes
 class FunctionTest(test.TestCase):
 
@@ -140,8 +150,8 @@
 
     @function.defun
     def f():
-      v = resource_variable_ops.ResourceVariable(1.0)
-      return v.read_value()
+      self.v = resource_variable_ops.ResourceVariable(1.0)
+      return self.v.read_value()
 
     self.assertAllEqual(f(), 1.0)
 
@@ -399,9 +409,9 @@
 
     @function.defun
     def tensor_init():
-      v = resource_variable_ops.ResourceVariable(
+      self.v = resource_variable_ops.ResourceVariable(
           lambda: constant_op.constant(2.0))
-      return v.read_value()
+      return self.v.read_value()
 
     value = tensor_init()
     if not context.executing_eagerly():
@@ -415,8 +425,8 @@
     def tensor_init():
       with ops.init_scope():
         const = constant_op.constant(2.0)
-      v = resource_variable_ops.ResourceVariable(const)
-      return v.read_value()
+      self.v = resource_variable_ops.ResourceVariable(const)
+      return self.v.read_value()
 
     value = tensor_init()
     if not context.executing_eagerly():
@@ -429,10 +439,17 @@
     def f():
       x = constant_op.constant([[1, 2], [3, 4]])
       out = math_ops.matmul(v, x)
-      self.assertEqual(out.get_shape(), tensor_shape.TensorShape([2, 2]))
+      self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
+      # We do not return v directly since the tensor conversion function of
+      # ResourceVariable returns the read value and not the resource itself.
+      return v._handle
 
     compiled = function.defun(f)
-    compiled()
+    var_handle = compiled()
+    self.assertEqual(var_handle.dtype, dtypes.resource)
+    self.assertEqual(var_handle.shape, tensor_shape.scalar())
+    var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype)
+    self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2]))
 
   def testVariableInLoopInFunction(self):
 
@@ -456,10 +473,17 @@
       def f():
         x = constant_op.constant([[1, 2], [3, 4]])
         out = math_ops.matmul(v, x)
-        self.assertEqual(out.get_shape(), tensor_shape.TensorShape([2, 2]))
+        self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
+        # We do not return v directly since the tensor conversion function of
+        # ResourceVariable returns the read value and not the resource itself.
+        return v._handle
 
       compiled = function.defun(f)
-      compiled()
+      var_handle = compiled()
+      self.assertEqual(var_handle.dtype, dtypes.resource)
+      self.assertEqual(var_handle.shape, tensor_shape.scalar())
+      var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype)
+      self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2]))
 
   def testDefunShapeInferenceWithCapturedVariableInGraphMode(self):
     with context.graph_mode():
@@ -468,23 +492,46 @@
       def f():
         x = constant_op.constant([[1, 2], [3, 4]])
         out = math_ops.matmul(v, x)
-        self.assertEqual(out.get_shape(), tensor_shape.TensorShape([2, 2]))
+        self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
 
       # Check that shape inference works while creating the defun
       compiled = function.defun(f)
       compiled()
 
+  def testDefunShapeInferenceWithCapturedTensorListInGraphMode(self):
+    with context.graph_mode():
+      tensor_list = list_ops.empty_tensor_list(
+          element_dtype=dtypes.float32,
+          element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
+      tensor_list = list_ops.tensor_list_push_back(tensor_list,
+                                                   constant_op.constant(1.0))
+      tensor_list = list_ops.tensor_list_push_back(tensor_list,
+                                                   constant_op.constant(2.0))
+
+      def f():
+        tl, value = list_ops.tensor_list_pop_back(
+            tensor_list, element_dtype=dtypes.float32)
+        self.assertEqual(value.shape, tensor_shape.scalar())
+        return tl
+
+      compiled = function.defun(f)
+      output_tensor_list = compiled()
+      _, value = list_ops.tensor_list_pop_back(
+          output_tensor_list, element_dtype=dtypes.float32)
+      self.assertEqual(value.shape, tensor_shape.scalar())
+
   @test_util.run_in_graph_and_eager_modes
   def testDefunForcesResourceVariables(self):
 
     def variable_creator():
-      return variables.Variable(0.0).read_value()
+      self.v = variables.Variable(0.0)
+      return self.v.read_value()
 
+    self.v = None
     defined = function.defun(variable_creator)
     defined()  # Create the variable.
-    self.assertEqual(len(defined.variables), 1)
     self.assertIsInstance(
-        defined.variables[0], resource_variable_ops.ResourceVariable)
+        self.v, resource_variable_ops.ResourceVariable)
 
   def testDefunDifferentiable(self):
     v = resource_variable_ops.ResourceVariable(1.0)
@@ -1184,13 +1231,11 @@
     defined = function.defun(foo)
 
     x = constant_op.constant([1.0])
-    self.assertAllEqual(defined.variables, [])
-    _ = defined(x)
-    self.assertAllEqual(defined.variables, [v])
+    self.assertEqual(1., self.evaluate(defined(x)))
+    v.assign(2.)
 
     x = constant_op.constant([1.0, 2.0])
-    _ = defined(x)  # ensure the variables list remains the same
-    self.assertAllEqual(defined.variables, [v])
+    self.assertAllEqual([2., 4.], self.evaluate(defined(x)))
 
   def testPythonFunctionWithDefaultArgs(self):
 
@@ -1685,6 +1730,88 @@
         # pylint: disable=protected-access
         self.assertEqual(len(graph._functions), 1)
 
+  def testCallingFunctionWithDifferentVariables(self):
+
+    @function.defun
+    def foo(v):
+      v.assign_add(1.0)
+      return v.read_value()
+
+    v = resource_variable_ops.ResourceVariable(0.0)
+    graph_function = foo.get_concrete_function(v)
+    self.assertEqual(len(graph_function.inputs), 1)
+    self.assertEqual(len(graph_function.captured_inputs), 0)
+
+    self.assertEqual(float(graph_function(v)), 1.0)
+    self.assertEqual(float(graph_function(v)), 2.0)
+
+    w = resource_variable_ops.ResourceVariable(0.0)
+
+    @function.defun
+    def bar(v):
+      del v
+      return constant_op.constant(1.0)
+
+    graph_function = bar.get_concrete_function(v)
+    self.assertEqual(float(graph_function(v)), 1.0)
+    self.assertEqual(float(graph_function(w)), 1.0)
+
+  def testCallingFunctionWithNonTensorsFails(self):
+
+    @function.defun
+    def foo(x):
+      return x
+
+    graph_function = foo.get_concrete_function(constant_op.constant(1.0))
+    with self.assertRaisesRegexp(ValueError, 'All inputs to `Function`s must '
+                                 'be Tensors;.*'):
+      graph_function('Not a Tensor.')
+
+  def testSwapImplementationWithGrapplerPlugin(self):
+    rewrites = rewriter_config_pb2.RewriterConfig()
+    # function_optimizer has to be turned off, otherwise it will delete the
+    # registered function if it does not get called.
+    # TODO(scottzhu): Move the ExperimentalImplementationSelector to be called
+    # before function_optimizer in the future.
+    rewrites.function_optimization = rewriter_config_pb2.RewriterConfig.OFF
+    customer_optimizer = rewrites.custom_optimizers.add()
+    customer_optimizer.name = 'ExperimentalImplementationSelector'
+    rewrites.min_graph_nodes = -1
+    graph_options = config_pb2.GraphOptions(
+        rewrite_options=rewrites, build_cost_model=1)
+    config = config_pb2.ConfigProto(graph_options=graph_options)
+
+    with context.graph_mode(), self.cached_session(
+        config=config, graph=ops.Graph(), use_gpu=True) as sess:
+
+      @function.defun_with_attributes(
+          attributes={
+              'experimental_api_implements': 'random_boost',
+              'experimental_api_preferred_device': 'CPU'
+          })
+      def cpu_boost(x):
+        return math_ops.add(x, 2.0)
+
+      @function.defun_with_attributes(
+          attributes={
+              'experimental_api_implements': 'random_boost',
+              'experimental_api_preferred_device': 'GPU'
+          })
+      def gpu_boost(x):
+        return math_ops.add(x, 4.0)
+
+      x = constant_op.constant(1.0)
+
+      function.register(cpu_boost, x)
+      y = gpu_boost(x)
+      y_value = sess.run(y)
+
+      if test.is_gpu_available():
+        self.assertEquals(y_value, 5.0)
+      else:
+        # Grappler falls back to the CPU impl even for GPU function calls.
+        self.assertEquals(y_value, 3.0)
+
 
 @test_util.with_c_shapes
 class AutomaticControlDependenciesTest(test.TestCase):
@@ -1876,10 +2003,10 @@
 
     @function.defun
     def train():
-      v = resource_variable_ops.ResourceVariable(1.0)
-      grad = backprop.implicit_grad(loss)(v)
+      self.v = resource_variable_ops.ResourceVariable(1.0)
+      grad = backprop.implicit_grad(loss)(self.v)
       optimizer.apply_gradients(grad)
-      return v.read_value()
+      return self.v.read_value()
 
     value = train()
     self.assertEqual(value.numpy(), -1.0)
@@ -1906,10 +2033,10 @@
 
     @function.defun
     def train():
-      v = resource_variable_ops.ResourceVariable(1.0)
-      grad = backprop.implicit_grad(loss)(v)
+      self.v = resource_variable_ops.ResourceVariable(1.0)
+      grad = backprop.implicit_grad(loss)(self.v)
       optimizer.apply_gradients(grad)
-      return v.read_value()
+      return self.v.read_value()
 
     train()
 
@@ -2096,6 +2223,13 @@
 
       modify_same_flat(nested_input)
 
+  def testDecoratedMethodVariableCleanup(self):
+    m = DefunnedMiniModel()
+    m(array_ops.ones([1, 2]))
+    weak_variables = weakref.WeakSet(m.variables)
+    self.assertEqual(2, len(weak_variables))
+    del m
+    self.assertEqual([], list(weak_variables))
 
 if __name__ == '__main__':
   ops.enable_eager_execution(
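A hedged sketch of the pattern these test edits converge on (function-local variables moved onto `self`): `FuncGraph` now keeps only weak references, so a variable created inside a `defun` and not stored anywhere else can be garbage-collected, and later calls would hit the new AssertionError raised by `FuncGraph.variables`.

```python
# Sketch only: why the tests store created variables on an object attribute.
@function.defun
def fragile():
  v = resource_variable_ops.ResourceVariable(1.0)  # only weakly referenced
  return v.read_value()

class Holder(object):

  def __init__(self):
    self.v = None

  @function.defun
  def robust(self):
    if self.v is None:
      self.v = resource_variable_ops.ResourceVariable(1.0)  # strong reference
    return self.v.read_value()
```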
diff --git a/tensorflow/python/eager/imperative_grad.py b/tensorflow/python/eager/imperative_grad.py
index 5f027d1..5f5af4a 100644
--- a/tensorflow/python/eager/imperative_grad.py
+++ b/tensorflow/python/eager/imperative_grad.py
@@ -23,8 +23,9 @@
 from tensorflow.python import pywrap_tensorflow
 
 
-VSpace = collections.namedtuple(
-    "VSpace", ["aggregate_fn", "num_elements_fn", "zeros", "ones"])
+VSpace = collections.namedtuple("VSpace", [
+    "aggregate_fn", "num_elements_fn", "zeros_fn", "ones_fn", "graph_shape_fn"
+])
 
 
 def imperative_grad(
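The renamed and extended fields group the Python callbacks that the C++ tape calls back into; the concrete functions live in `backprop.py`, so the values below are placeholders (illustrative only, not the actual registration). `graph_shape_fn` is the new field: it returns a dynamic shape Tensor for graph tensors whose static shape is partially unknown.

```python
# Illustrative only: shape of a VSpace registration using the renamed fields.
from tensorflow.python.eager.imperative_grad import VSpace
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops

_vspace = VSpace(
    aggregate_fn=math_ops.add_n,                       # placeholder choices
    num_elements_fn=lambda t: t.shape.num_elements(),
    zeros_fn=array_ops.zeros,
    ones_fn=array_ops.ones,
    graph_shape_fn=array_ops.shape)
```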
diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc
index f34ce6a..5f44bd4 100644
--- a/tensorflow/python/eager/pywrap_tensor.cc
+++ b/tensorflow/python/eager/pywrap_tensor.cc
@@ -516,25 +516,13 @@
 // Getter for `_num_elements`.
 static PyObject* EagerTensor_num_elements(EagerTensor* self) {
   auto handle = self->handle;
-  int n = TFE_TensorHandleNumDims(handle, self->status);
+  int n = TFE_TensorHandleNumElements(handle, self->status);
   if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
     // Cleanup self->status before returning.
     TF_SetStatus(self->status, TF_OK, "");
     return nullptr;
   }
-  tensorflow::int64 value = 1;
-  if (PyErr_Occurred()) return nullptr;
-  for (int i = 0; i < n; ++i) {
-    int64_t dim = TFE_TensorHandleDim(handle, i, self->status);
-    if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
-      // Cleanup self->status before returning.
-      TF_SetStatus(self->status, TF_OK, "");
-      PyErr_SetString(PyExc_RuntimeError, "Error while iterating dimensions");
-      return nullptr;
-    }
-    value *= dim;
-  }
-  return PyLong_FromLongLong(value);
+  return PyLong_FromLongLong(n);
 }
 
 static PyObject* EagerTensor_tensor_handle(EagerTensor* self, void* unused) {
@@ -777,17 +765,34 @@
   return reinterpret_cast<PyObject*>(t);
 }
 
-tensorflow::int64 EagerTensor_id(const PyObject* tensor) {
-  CHECK(EagerTensor_CheckExact(tensor));
+tensorflow::int64 PyEagerTensor_ID(const PyObject* tensor) {
+  DCHECK(EagerTensor_CheckExact(tensor));
   return reinterpret_cast<const EagerTensor*>(tensor)->id;
 }
 
-tensorflow::DataType EagerTensor_dtype(const PyObject* tensor) {
-  CHECK(EagerTensor_CheckExact(tensor));
+tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor) {
+  DCHECK(EagerTensor_CheckExact(tensor));
   return static_cast<tensorflow::DataType>(TFE_TensorHandleDataType(
       reinterpret_cast<const EagerTensor*>(tensor)->handle));
 }
 
+tensorflow::int64 PyEagerTensor_NumElements(const PyObject* tensor) {
+  DCHECK(EagerTensor_CheckExact(tensor));
+  const EagerTensor* as_c_eager_tensor =
+      reinterpret_cast<const EagerTensor*>(tensor);
+  tensorflow::int64 result = TFE_TensorHandleNumElements(
+      as_c_eager_tensor->handle, as_c_eager_tensor->status);
+
+  if (MaybeRaiseExceptionFromTFStatus(as_c_eager_tensor->status,
+                                      PyExc_ValueError)) {
+    // Cleanup status before returning.
+    TF_SetStatus(as_c_eager_tensor->status, TF_OK, "");
+    return -1;
+  }
+
+  return result;
+}
+
 PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
   if (!PyType_Check(base_class)) {
     PyErr_SetString(
diff --git a/tensorflow/python/eager/pywrap_tensor.h b/tensorflow/python/eager/pywrap_tensor.h
index bc042eb..4eaa1ba 100644
--- a/tensorflow/python/eager/pywrap_tensor.h
+++ b/tensorflow/python/eager/pywrap_tensor.h
@@ -21,8 +21,9 @@
 #include "tensorflow/python/lib/core/numpy.h"
 
 bool EagerTensor_CheckExact(const PyObject* o);
-tensorflow::int64 EagerTensor_id(const PyObject* tensor);
-tensorflow::DataType EagerTensor_dtype(const PyObject* tensor);
+tensorflow::int64 PyEagerTensor_ID(const PyObject* tensor);
+tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor);
+tensorflow::int64 PyEagerTensor_NumElements(const PyObject* tensor);
 
 namespace tensorflow {
 TFE_TensorHandle* ConvertToEagerTensor(PyObject* value, PyObject* dtype);
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 9f2f4e0..196e20e 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -17,6 +17,7 @@
 
 #include "tensorflow/python/eager/pywrap_tfe.h"
 
+#include "absl/types/variant.h"
 #include "tensorflow/c/c_api.h"
 #include "tensorflow/c/c_api_internal.h"
 #include "tensorflow/c/eager/c_api_internal.h"
@@ -860,7 +861,7 @@
 
 static tensorflow::int64 FastTensorId(PyObject* tensor) {
   if (EagerTensor_CheckExact(tensor)) {
-    return EagerTensor_id(tensor);
+    return PyEagerTensor_ID(tensor);
   }
   PyObject* id_field = PyObject_GetAttrString(tensor, "_id");
   if (id_field == nullptr) {
@@ -873,7 +874,7 @@
 
 static tensorflow::DataType FastTensorDtype(PyObject* tensor) {
   if (EagerTensor_CheckExact(tensor)) {
-    return EagerTensor_dtype(tensor);
+    return PyEagerTensor_Dtype(tensor);
   }
   PyObject* dtype_field = PyObject_GetAttrString(tensor, "dtype");
   if (dtype_field == nullptr) {
@@ -889,12 +890,239 @@
   return static_cast<tensorflow::DataType>(id);
 }
 
+class PyTapeTensor {
+ public:
+  PyTapeTensor(tensorflow::int64 id, tensorflow::DataType dtype,
+               const tensorflow::TensorShape& shape)
+      : id_(id), dtype_(dtype), shape_(shape) {}
+  PyTapeTensor(tensorflow::int64 id, tensorflow::DataType dtype,
+               PyObject* shape)
+      : id_(id), dtype_(dtype), shape_(shape) {
+    Py_INCREF(absl::get<1>(shape_));
+  }
+  PyTapeTensor(const PyTapeTensor& other) {
+    id_ = other.id_;
+    dtype_ = other.dtype_;
+    shape_ = other.shape_;
+    if (shape_.index() == 1) {
+      Py_INCREF(absl::get<1>(shape_));
+    }
+  }
+
+  ~PyTapeTensor() {
+    if (shape_.index() == 1) {
+      Py_DECREF(absl::get<1>(shape_));
+    }
+  }
+  PyObject* GetShape() const;
+  PyObject* GetDType() const { return PyLong_FromLong(dtype_); }
+  tensorflow::int64 GetID() const { return id_; }
+
+ private:
+  tensorflow::int64 id_;
+  tensorflow::DataType dtype_;
+  absl::variant<tensorflow::TensorShape, PyObject*> shape_;
+};
+
+class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction,
+                                                  PyTapeTensor> {
+ public:
+  explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {
+    Py_INCREF(py_vspace_);
+  }
+
+  tensorflow::Status Initialize() {
+    num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
+    if (num_elements_ == nullptr) {
+      return tensorflow::errors::InvalidArgument("invalid vspace");
+    }
+    aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn");
+    if (aggregate_fn_ == nullptr) {
+      return tensorflow::errors::InvalidArgument("invalid vspace");
+    }
+    zeros_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_fn");
+    if (zeros_fn_ == nullptr) {
+      return tensorflow::errors::InvalidArgument("invalid vspace");
+    }
+    ones_fn_ = PyObject_GetAttrString(py_vspace_, "ones_fn");
+    if (ones_fn_ == nullptr) {
+      return tensorflow::errors::InvalidArgument("invalid vspace");
+    }
+    graph_shape_fn_ = PyObject_GetAttrString(py_vspace_, "graph_shape_fn");
+    if (graph_shape_fn_ == nullptr) {
+      return tensorflow::errors::InvalidArgument("invalid vspace");
+    }
+    return tensorflow::Status::OK();
+  }
+
+  ~PyVSpace() override {
+    Py_XDECREF(num_elements_);
+    Py_XDECREF(aggregate_fn_);
+    Py_XDECREF(zeros_fn_);
+    Py_XDECREF(ones_fn_);
+    Py_XDECREF(graph_shape_fn_);
+
+    Py_DECREF(py_vspace_);
+  }
+
+  tensorflow::int64 NumElements(PyObject* tensor) const final {
+    if (EagerTensor_CheckExact(tensor)) {
+      return PyEagerTensor_NumElements(tensor);
+    }
+    PyObject* arglist =
+        Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
+    PyObject* result = PyEval_CallObject(num_elements_, arglist);
+    Py_DECREF(arglist);
+    if (result == nullptr) {
+      // The caller detects whether a python exception has been raised.
+      return -1;
+    }
+    tensorflow::int64 r = MakeInt(result);
+    Py_DECREF(result);
+    return r;
+  }
+
+  PyObject* AggregateGradients(
+      tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final {
+    PyObject* list = PyList_New(gradient_tensors.size());
+    for (int i = 0; i < gradient_tensors.size(); ++i) {
+      // Note: stealing a reference to the gradient tensors.
+      CHECK(gradient_tensors[i] != nullptr);
+      CHECK(gradient_tensors[i] != Py_None);
+      PyList_SET_ITEM(list, i,
+                      reinterpret_cast<PyObject*>(gradient_tensors[i]));
+    }
+    PyObject* arglist = Py_BuildValue("(O)", list);
+    CHECK(arglist != nullptr);
+    PyObject* result = PyEval_CallObject(aggregate_fn_, arglist);
+    Py_DECREF(arglist);
+    Py_DECREF(list);
+    return result;
+  }
+
+  void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); }
+
+  PyObject* Zeros(const PyTapeTensor& tensor) const final {
+    PyObject* py_shape = tensor.GetShape();
+    PyObject* py_dtype = tensor.GetDType();
+    PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
+    PyObject* result = PyEval_CallObject(zeros_fn_, arg_list);
+    Py_DECREF(arg_list);
+    Py_DECREF(py_dtype);
+    Py_DECREF(py_shape);
+    return reinterpret_cast<PyObject*>(result);
+  }
+
+  PyObject* Ones(const PyTapeTensor& tensor) const final {
+    PyObject* py_shape = tensor.GetShape();
+    PyObject* py_dtype = tensor.GetDType();
+    PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
+    PyObject* result = PyEval_CallObject(ones_fn_, arg_list);
+    Py_DECREF(arg_list);
+    Py_DECREF(py_dtype);
+    Py_DECREF(py_shape);
+    return result;
+  }
+
+  PyObject* GraphShape(PyObject* tensor) const {
+    PyObject* arg_list = Py_BuildValue("(O)", tensor);
+    PyObject* result = PyEval_CallObject(graph_shape_fn_, arg_list);
+    Py_DECREF(arg_list);
+    return result;
+  }
+
+  tensorflow::Status CallBackwardFunction(
+      PyBackwardFunction* backward_function,
+      tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
+      std::vector<PyObject*>* result) const final {
+    PyObject* grads = PyTuple_New(output_gradients.size());
+    for (int i = 0; i < output_gradients.size(); ++i) {
+      if (output_gradients[i] == nullptr) {
+        Py_INCREF(Py_None);
+        PyTuple_SET_ITEM(grads, i, Py_None);
+      } else {
+        PyTuple_SET_ITEM(grads, i,
+                         reinterpret_cast<PyObject*>(output_gradients[i]));
+      }
+    }
+    PyObject* py_result = (*backward_function)(grads);
+    Py_DECREF(grads);
+    if (py_result == nullptr) {
+      return tensorflow::errors::Internal("gradient function threw exceptions");
+    }
+    result->clear();
+    PyObject* seq =
+        PySequence_Fast(py_result, "expected a sequence of gradients");
+    if (seq == nullptr) {
+      return tensorflow::errors::InvalidArgument(
+          "gradient function did not return a list");
+    }
+    int len = PySequence_Fast_GET_SIZE(seq);
+    VLOG(1) << "Gradient length is " << len;
+    result->reserve(len);
+    for (int i = 0; i < len; ++i) {
+      PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
+      if (item == Py_None) {
+        result->push_back(nullptr);
+      } else {
+        Py_INCREF(item);
+        result->push_back(item);
+      }
+    }
+    Py_DECREF(seq);
+    Py_DECREF(py_result);
+    return tensorflow::Status::OK();
+  }
+
+  void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); }
+
+ private:
+  PyObject* py_vspace_;
+
+  PyObject* num_elements_;
+  PyObject* aggregate_fn_;
+  PyObject* zeros_fn_;
+  PyObject* ones_fn_;
+  PyObject* graph_shape_fn_;
+};
+PyVSpace* py_vspace = nullptr;
+
+PyObject* TFE_Py_RegisterVSpace(PyObject* e) {
+  if (py_vspace != nullptr) {
+    delete py_vspace;
+  }
+
+  py_vspace = new PyVSpace(e);
+  auto status = py_vspace->Initialize();
+  if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
+    delete py_vspace;
+    return nullptr;
+  }
+
+  Py_RETURN_NONE;
+}
+
+PyObject* PyTapeTensor::GetShape() const {
+  if (shape_.index() == 0) {
+    auto& shape = absl::get<0>(shape_);
+    PyObject* py_shape = PyTuple_New(shape.dims());
+    for (int i = 0; i < shape.dims(); ++i) {
+      PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
+    }
+
+    return py_shape;
+  }
+
+  return py_vspace->GraphShape(absl::get<1>(shape_));
+}
+
 class GradientTape
-    : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction> {
+    : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
+                                             PyTapeTensor> {
  public:
   explicit GradientTape(bool persistent, bool watch_accessed_variables)
-      : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction>(
-            persistent),
+      : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
+                                        PyTapeTensor>(persistent),
         watch_accessed_variables_(watch_accessed_variables) {}
 
   virtual ~GradientTape() {
@@ -1175,24 +1403,41 @@
   reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Watch(tensor_id);
 }
 
-static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
+bool ListContainsNone(PyObject* list) {
+  if (list == Py_None) return true;
+  tensorflow::Safe_PyObjectPtr seq(
+      PySequence_Fast(list, "expected a sequence"));
+  if (seq == nullptr) {
+    return false;
+  }
+
+  int len = PySequence_Size(list);
+  for (int i = 0; i < len; ++i) {
+    PyObject* item = PySequence_Fast_GET_ITEM(seq.get(), i);
+    if (item == Py_None) return true;
+  }
+
+  return false;
+}
+
+static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) {
   if (EagerTensor_CheckExact(tensor)) {
     TFE_TensorHandle* t = EagerTensor_Handle(tensor);
-    tensorflow::int64 id = EagerTensor_id(tensor);
+    tensorflow::int64 id = PyEagerTensor_ID(tensor);
     tensorflow::TensorShape tensor_shape;
     const tensorflow::Status status = t->handle->Shape(&tensor_shape);
 
     if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
-      return tensorflow::eager::TapeTensor{id, t->handle->dtype,
-                                           tensorflow::TensorShape({})};
+      return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
+                          tensorflow::TensorShape({}));
     } else {
-      return tensorflow::eager::TapeTensor{id, t->handle->dtype, tensor_shape};
+      return PyTapeTensor(id, t->handle->dtype, tensor_shape);
     }
   }
   tensorflow::int64 id = FastTensorId(tensor);
   if (PyErr_Occurred()) {
-    return tensorflow::eager::TapeTensor{
-        id, static_cast<tensorflow::DataType>(0), tensorflow::TensorShape({})};
+    return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
+                        tensorflow::TensorShape({}));
   }
   PyObject* dtype_object = PyObject_GetAttrString(tensor, "dtype");
   PyObject* dtype_enum = PyObject_GetAttrString(dtype_object, "_type_enum");
@@ -1200,16 +1445,21 @@
   tensorflow::DataType dtype =
       static_cast<tensorflow::DataType>(MakeInt(dtype_enum));
   Py_DECREF(dtype_enum);
-  if (PyErr_Occurred() != nullptr) {
-    return tensorflow::eager::TapeTensor{id, dtype,
-                                         tensorflow::TensorShape({})};
+  if (PyErr_Occurred()) {
+    return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
+                        tensorflow::TensorShape({}));
   }
   static char _shape_tuple[] = "_shape_tuple";
   PyObject* shape_tuple = PyObject_CallMethod(tensor, _shape_tuple, nullptr);
-  if (PyErr_Occurred() != nullptr) {
-    return tensorflow::eager::TapeTensor{id, dtype,
-                                         tensorflow::TensorShape({})};
+  if (PyErr_Occurred()) {
+    return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
+                        tensorflow::TensorShape({}));
   }
+
+  if (ListContainsNone(shape_tuple)) {
+    return PyTapeTensor(id, dtype, tensor);
+  }
+
   auto l = MakeIntList(shape_tuple);
   Py_DECREF(shape_tuple);
   // Replace -1, which represents accidental Nones which can occur in graph mode
@@ -1220,7 +1470,7 @@
     }
   }
   tensorflow::TensorShape shape(l);
-  return tensorflow::eager::TapeTensor{id, dtype, shape};
+  return PyTapeTensor(id, dtype, shape);
 }
 
 std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) {
@@ -1286,7 +1536,7 @@
     const std::vector<tensorflow::DataType>& input_dtypes,
     const std::function<PyBackwardFunction*()>& backward_function_getter,
     const std::function<void(PyBackwardFunction*)>& backward_function_killer) {
-  std::vector<tensorflow::eager::TapeTensor> output_info;
+  std::vector<PyTapeTensor> output_info;
   PyObject* seq = PySequence_Fast(output_tensors,
                                   "expected a sequence of integer tensor ids");
   int len = PySequence_Size(output_tensors);
@@ -1362,177 +1612,6 @@
   }
 }
 
-class PyVSpace
-    : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction> {
- public:
-  explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {
-    Py_INCREF(py_vspace_);
-  }
-
-  tensorflow::Status Initialize() {
-    num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
-    if (num_elements_ == nullptr) {
-      return tensorflow::errors::InvalidArgument("invalid vspace");
-    }
-    aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn");
-    if (aggregate_fn_ == nullptr) {
-      return tensorflow::errors::InvalidArgument("invalid vspace");
-    }
-    zeros_ = PyObject_GetAttrString(py_vspace_, "zeros");
-    if (zeros_ == nullptr) {
-      return tensorflow::errors::InvalidArgument("invalid vspace");
-    }
-    ones_ =
-        PyObject_GetAttrString(reinterpret_cast<PyObject*>(py_vspace_), "ones");
-    if (ones_ == nullptr) {
-      return tensorflow::errors::InvalidArgument("invalid vspace");
-    }
-    return tensorflow::Status::OK();
-  }
-
-  ~PyVSpace() override {
-    Py_XDECREF(num_elements_);
-    Py_XDECREF(aggregate_fn_);
-    Py_XDECREF(zeros_);
-    Py_XDECREF(ones_);
-
-    Py_DECREF(py_vspace_);
-  }
-
-  tensorflow::int64 NumElements(PyObject* tensor) const final {
-    PyObject* arglist =
-        Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
-    PyObject* result = PyEval_CallObject(num_elements_, arglist);
-    Py_DECREF(arglist);
-    if (result == nullptr) {
-      // The caller detects whether a python exception has been raised.
-      return -1;
-    }
-    tensorflow::int64 r = MakeInt(result);
-    Py_DECREF(result);
-    return r;
-  }
-
-  PyObject* AggregateGradients(
-      tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final {
-    PyObject* list = PyList_New(gradient_tensors.size());
-    for (int i = 0; i < gradient_tensors.size(); ++i) {
-      // Note: stealing a reference to the gradient tensors.
-      CHECK(gradient_tensors[i] != nullptr);
-      CHECK(gradient_tensors[i] != Py_None);
-      PyList_SET_ITEM(list, i,
-                      reinterpret_cast<PyObject*>(gradient_tensors[i]));
-    }
-    PyObject* arglist = Py_BuildValue("(O)", list);
-    CHECK(arglist != nullptr);
-    PyObject* result = PyEval_CallObject(aggregate_fn_, arglist);
-    Py_DECREF(arglist);
-    Py_DECREF(list);
-    return result;
-  }
-
-  void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); }
-
-  PyObject* Zeros(tensorflow::TensorShape shape,
-                  tensorflow::DataType dtype) const final {
-    PyObject* py_shape = PyTuple_New(shape.dims());
-    for (int i = 0; i < shape.dims(); ++i) {
-      PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
-    }
-    PyObject* py_dtype = PyLong_FromLong(static_cast<int>(dtype));
-    PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
-    PyObject* result = PyEval_CallObject(zeros_, arg_list);
-    Py_DECREF(arg_list);
-    Py_DECREF(py_dtype);
-    Py_DECREF(py_shape);
-    return reinterpret_cast<PyObject*>(result);
-  }
-
-  PyObject* Ones(tensorflow::TensorShape shape,
-                 tensorflow::DataType dtype) const final {
-    PyObject* py_shape = PyTuple_New(shape.dims());
-    for (int i = 0; i < shape.dims(); ++i) {
-      PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
-    }
-    PyObject* py_dtype = PyLong_FromLong(static_cast<int>(dtype));
-    PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
-    PyObject* result = PyEval_CallObject(ones_, arg_list);
-    Py_DECREF(arg_list);
-    Py_DECREF(py_dtype);
-    Py_DECREF(py_shape);
-    return result;
-  }
-
-  tensorflow::Status CallBackwardFunction(
-      PyBackwardFunction* backward_function,
-      tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
-      std::vector<PyObject*>* result) const final {
-    PyObject* grads = PyTuple_New(output_gradients.size());
-    for (int i = 0; i < output_gradients.size(); ++i) {
-      if (output_gradients[i] == nullptr) {
-        Py_INCREF(Py_None);
-        PyTuple_SET_ITEM(grads, i, Py_None);
-      } else {
-        PyTuple_SET_ITEM(grads, i,
-                         reinterpret_cast<PyObject*>(output_gradients[i]));
-      }
-    }
-    PyObject* py_result = (*backward_function)(grads);
-    Py_DECREF(grads);
-    if (py_result == nullptr) {
-      return tensorflow::errors::Internal("gradient function threw exceptions");
-    }
-    result->clear();
-    PyObject* seq =
-        PySequence_Fast(py_result, "expected a sequence of gradients");
-    if (seq == nullptr) {
-      return tensorflow::errors::InvalidArgument(
-          "gradient function did not return a list");
-    }
-    int len = PySequence_Fast_GET_SIZE(seq);
-    VLOG(1) << "Gradient length is " << len;
-    result->reserve(len);
-    for (int i = 0; i < len; ++i) {
-      PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
-      if (item == Py_None) {
-        result->push_back(nullptr);
-      } else {
-        Py_INCREF(item);
-        result->push_back(item);
-      }
-    }
-    Py_DECREF(seq);
-    Py_DECREF(py_result);
-    return tensorflow::Status::OK();
-  }
-
-  void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); }
-
- private:
-  PyObject* py_vspace_;
-
-  PyObject* num_elements_;
-  PyObject* aggregate_fn_;
-  PyObject* zeros_;
-  PyObject* ones_;
-};
-PyVSpace* py_vspace = nullptr;
-
-PyObject* TFE_Py_RegisterVSpace(PyObject* e) {
-  if (py_vspace != nullptr) {
-    delete py_vspace;
-  }
-
-  py_vspace = new PyVSpace(e);
-  auto status = py_vspace->Initialize();
-  if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
-    delete py_vspace;
-    return nullptr;
-  }
-
-  Py_RETURN_NONE;
-}
-
 std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
   PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
   if (seq == nullptr) {
@@ -1744,6 +1823,9 @@
   Py_RETURN_NONE;
 }
 
+// TODO(agarwal): use an automatic mechanism for handling None arguments to
+// gradient functions.
+
 // Returns a pair where the first value of the pair indicates whether or not all
 // outputs are unused. If the first value is false, the second value is a
 // set that identifies which of the output indices are unused.
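A hedged Python-side sketch of the case the new `ListContainsNone`/`PyTapeTensor` path covers: when a tensor's static shape contains `None`, the tape keeps the tensor object itself and later asks `graph_shape_fn` for a dynamic shape instead of materializing a `TensorShape`.

```python
# Sketch only: a partially unknown static shape takes the PyObject-shape path.
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops

with context.graph_mode():
  x = array_ops.placeholder(dtypes.float32, shape=[None, 3])
  print(x._shape_tuple())  # (None, 3) -> tape stores `x`, not a TensorShape
```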
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index bfcc019..7f23499 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -197,6 +197,7 @@
     srcs = ["canned/boosted_trees.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":boosted_trees_utils",
         ":estimator",
         ":head",
         ":model_fn",
@@ -224,6 +225,35 @@
 )
 
 py_library(
+    name = "boosted_trees_utils",
+    srcs = ["canned/boosted_trees_utils.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":estimator",
+        ":head",
+        ":model_fn",
+        "//tensorflow:tensorflow_py_no_contrib",
+    ],
+)
+
+py_test(
+    name = "boosted_trees_utils_test",
+    size = "medium",
+    srcs = ["canned/boosted_trees_utils_test.py"],
+    shard_count = 2,
+    srcs_version = "PY2AND3",
+    tags = [
+        "optonly",
+    ],
+    deps = [
+        ":boosted_trees",
+        ":inputs",
+        "//tensorflow:tensorflow_py_no_contrib",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_library(
     name = "dnn",
     srcs = ["canned/dnn.py"],
     srcs_version = "PY2AND3",
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index 19f1801..756d32d 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -22,7 +22,8 @@
 import functools
 
 from tensorflow.python.estimator import estimator
-from tensorflow.python.estimator import model_fn
+from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.estimator.canned import boosted_trees_utils
 from tensorflow.python.estimator.canned import head as head_lib
 from tensorflow.python.feature_column import feature_column as feature_column_lib
 from tensorflow.python.framework import dtypes
@@ -36,6 +37,7 @@
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops.array_ops import identity as tf_identity
 from tensorflow.python.ops.losses import losses
 from tensorflow.python.summary import summary
 from tensorflow.python.training import session_run_hook
@@ -197,8 +199,7 @@
   cached_features = [
       _local_variable(
           array_ops.zeros([batch_size], dtype=dtypes.int32),
-          name='cached_feature_{}'.format(i))
-      for i in range(num_features)
+          name='cached_feature_{}'.format(i)) for i in range(num_features)
   ]
   are_features_cached = _local_variable(False, name='are_features_cached')
 
@@ -228,8 +229,7 @@
     return cached, cache_flip_op
 
   input_feature_list, cache_flip_op = control_flow_ops.cond(
-      are_features_cached,
-      lambda: (cached_features, control_flow_ops.no_op()),
+      are_features_cached, lambda: (cached_features, control_flow_ops.no_op()),
       cache_features_and_return)
   return input_feature_list, cache_flip_op
 
@@ -263,8 +263,8 @@
     elif dtypes.as_dtype(dtypes.string).is_compatible_with(example_ids.dtype):
       empty_key = ''
     else:
-      raise ValueError('Unsupported example_id_feature dtype %s.' %
-                       example_ids.dtype)
+      raise ValueError(
+          'Unsupported example_id_feature dtype %s.' % example_ids.dtype)
     # Cache holds latest <tree_id, node_id, logits> for each example.
     # tree_id and node_id are both int32 but logits is a float32.
     # To reduce the overhead, we store all of them together as float32 and
@@ -273,8 +273,8 @@
         empty_key=empty_key, value_dtype=dtypes.float32, value_shape=[3])
     self._example_ids = ops.convert_to_tensor(example_ids)
     if self._example_ids.shape.ndims not in (None, 1):
-      raise ValueError('example_id should have rank 1, but got %s' %
-                       self._example_ids)
+      raise ValueError(
+          'example_id should have rank 1, but got %s' % self._example_ids)
     self._logits_dimension = logits_dimension
 
   def lookup(self):
@@ -334,7 +334,7 @@
         array_ops.zeros([batch_size], dtype=dtypes.int32),
         name='tree_ids_cache')
     self._node_ids = _local_variable(
-        _DUMMY_NODE_ID*array_ops.ones([batch_size], dtype=dtypes.int32),
+        _DUMMY_NODE_ID * array_ops.ones([batch_size], dtype=dtypes.int32),
         name='node_ids_cache')
     self._logits = _local_variable(
         array_ops.zeros([batch_size, logits_dimension], dtype=dtypes.float32),
@@ -422,9 +422,13 @@
     self._pruning_mode_parsed = boosted_trees_ops.PruningMode.from_str(
         tree_hparams.pruning_mode)
 
-    if (self._pruning_mode_parsed != boosted_trees_ops.PruningMode.NO_PRUNING
-        and tree_hparams.tree_complexity <= 0):
-      raise ValueError('For pruning, tree_complexity must be positive.')
+    if tree_hparams.tree_complexity > 0:
+      if self._pruning_mode_parsed == boosted_trees_ops.PruningMode.NO_PRUNING:
+        raise ValueError(
+            'Tree complexity has no effect unless a pruning mode is chosen.')
+    else:
+      if self._pruning_mode_parsed != boosted_trees_ops.PruningMode.NO_PRUNING:
+        raise ValueError('For pruning, tree_complexity must be positive.')
     # pylint: enable=protected-access
 
   @abc.abstractmethod
@@ -719,7 +723,7 @@
     tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
 
     # Create logits.
-    if mode != model_fn.ModeKeys.TRAIN:
+    if mode != model_fn_lib.ModeKeys.TRAIN:
       input_feature_list = _get_transformed_features(features,
                                                      sorted_feature_columns)
       logits = boosted_trees_ops.predict(
@@ -886,6 +890,7 @@
       labels=labels,
       train_op_fn=_train_op_fn,
       logits=logits)
+
   # Add an early stop hook.
   estimator_spec = estimator_spec._replace(
       training_hooks=estimator_spec.training_hooks +
@@ -927,8 +932,8 @@
                                                 label_vocabulary):
   """Creates a head for classifier and the closed form gradients/hessians."""
   head = _create_classification_head(n_classes, weight_column, label_vocabulary)
-  if (n_classes == 2 and head.logits_dimension == 1 and weight_column is None
-      and label_vocabulary is None):
+  if (n_classes == 2 and head.logits_dimension == 1 and
+      weight_column is None and label_vocabulary is None):
     # Use the closed-form gradients/hessians for 2 class.
     def _grad_and_hess_for_logloss(logits, labels):
       """A closed form gradient and hessian for logistic loss."""
@@ -961,8 +966,196 @@
   # pylint: enable=protected-access
 
 
+def _bt_explanations_fn(features,
+                        head,
+                        sorted_feature_columns,
+                        name='boosted_trees'):
+  """Gradient Boosted Trees predict with explanations model_fn.
+
+  Args:
+    features: dict of `Tensor`.
+    head: A `head_lib._Head` instance.
+    sorted_feature_columns: Sorted iterable of `feature_column._FeatureColumn`
+      model inputs.
+    name: Name used for the model.
+
+  Returns:
+      An `EstimatorSpec` instance.
+
+  Raises:
+    ValueError: mode or params are invalid, or features has the wrong type.
+  """
+  mode = model_fn_lib.ModeKeys.PREDICT
+  with ops.name_scope(name) as name:
+    # Create Ensemble resources.
+    tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
+
+    input_feature_list = _get_transformed_features(features,
+                                                   sorted_feature_columns)
+
+    logits = boosted_trees_ops.predict(
+        # For non-TRAIN mode, ensemble doesn't change after initialization,
+        # so no local copy is needed; using tree_ensemble directly.
+        tree_ensemble_handle=tree_ensemble.resource_handle,
+        bucketized_features=input_feature_list,
+        logits_dimension=head.logits_dimension)
+
+    estimator_spec = head.create_estimator_spec(
+        features=features,
+        mode=mode,
+        labels=None,
+        train_op_fn=control_flow_ops.no_op,
+        logits=logits)
+
+    debug_op = boosted_trees_ops.example_debug_outputs(
+        tree_ensemble.resource_handle,
+        bucketized_features=input_feature_list,
+        logits_dimension=head.logits_dimension)
+    estimator_spec.predictions[boosted_trees_utils._DEBUG_PROTO_KEY] = debug_op  # pylint: disable=protected-access
+    return estimator_spec
+
+
+class _BoostedTreesBase(estimator.Estimator):
+  """Base class for boosted trees estimators.
+
+  This class is intended to keep tree-specific functions (e.g., methods for
+  feature importances and directional feature contributions) in one central
+  place.
+
+  It is not a valid (working) Estimator on its own and should only be used as a
+  base class.
+  """
+
+  def __init__(self, model_fn, model_dir, config, feature_columns, head,
+               center_bias, is_classification):
+    """Initializes a `_BoostedTreesBase` instance.
+
+    Args:
+      model_fn: Model function. See base class for more detail.
+      model_dir: Directory to save model parameters, graph, etc. See base
+        class for more detail.
+      config: `estimator.RunConfig` configuration object.
+      feature_columns: An iterable containing all the feature columns used by
+        the model. All items in the set should be instances of classes derived
+        from `FeatureColumn`.
+      head: A `head_lib._Head` instance.
+      center_bias: Whether bias centering needs to occur. Bias centering refers
+        to the first node in the very first tree returning the prediction that
+        is aligned with the original labels distribution. For example, for
+        regression problems, the first node will return the mean of the labels.
+        For binary classification problems, it will return a logit for a prior
+        probability of label 1.
+      is_classification: If the estimator is for classification.
+    """
+    super(_BoostedTreesBase, self).__init__(
+        model_fn=model_fn, model_dir=model_dir, config=config)
+    self._sorted_feature_columns = sorted(
+        feature_columns, key=lambda tc: tc.name)
+    self._head = head
+    self._n_features = _calculate_num_features(self._sorted_feature_columns)
+    self._center_bias = center_bias
+    self._is_classification = is_classification
+
+  def experimental_predict_with_explanations(self,
+                                             input_fn,
+                                             predict_keys=None,
+                                             hooks=None,
+                                             checkpoint_path=None):
+    """Computes model explainability outputs per example along with predictions.
+
+    Currently supports directional feature contributions (DFCs). For each
+    instance, DFCs indicate the aggregate contribution of each feature. See
+    https://arxiv.org/abs/1312.1121 and
+    http://blog.datadive.net/interpreting-random-forests/ for more details.
+    Args:
+      input_fn: A function that provides input data for predicting as
+        minibatches. See [Premade Estimators](
+        https://tensorflow.org/guide/premade_estimators#create_input_functions)
+          for more information. The function should construct and return one of
+          the following:
+          * A `tf.data.Dataset` object: outputs of the `Dataset` must be a
+            tuple `(features, labels)` with the same constraints as below.
+          * A tuple `(features, labels)`: `features` is a `tf.Tensor` or a
+            dictionary of string feature name to `Tensor`, and `labels` is a
+            `Tensor` or a dictionary of string label name to `Tensor`. Both are
+            consumed by `model_fn` and should satisfy its input expectations.
+      predict_keys: list of `str`, name of the keys to predict. It is used if
+        the `tf.estimator.EstimatorSpec.predictions` is a `dict`. If
+        `predict_keys` is used then rest of the predictions will be filtered
+        from the dictionary, with the exception of 'bias' and 'dfc', which will
+        always be in the dictionary. If `None`, returns all keys in prediction
+        dict, as well as two new keys 'dfc' and 'bias'.
+      hooks: List of `tf.train.SessionRunHook` subclass instances. Used for
+        callbacks inside the prediction call.
+      checkpoint_path: Path of a specific checkpoint to predict. If `None`, the
+        latest checkpoint in `model_dir` is used.  If there are no checkpoints
+        in `model_dir`, prediction is run with newly initialized `Variables`
+        instead of ones restored from checkpoint.
+
+    Yields:
+      Evaluated values of `predictions` tensors. The `predictions` tensors will
+      contain at least two keys 'dfc' and 'bias' for model explanations. The
+      'dfc' value corresponds to the contribution of each feature to the overall
+      prediction for this instance (a positive value makes class 1 more likely,
+      a negative value less likely). The 'bias'
+      value will be the same across all the instances, corresponding to the
+      probability (classification) or prediction (regression) of the training
+      data distribution.
+
+    Raises:
+      ValueError: when wrong arguments are given or unsupported functionalities
+       are requested.
+    """
+    if not self._center_bias:
+      raise ValueError('center_bias must be enabled during estimator '
+                       'instantiation when using '
+                       'experimental_predict_with_explanations.')
+    # pylint: disable=protected-access
+    if not self._is_classification:
+      identity_inverse_link_fn = self._head._inverse_link_fn in (None,
+                                                                 tf_identity)
+      # pylint:enable=protected-access
+      if not identity_inverse_link_fn:
+        raise ValueError(
+            'For now only identity inverse_link_fn in regression_head is '
+            'supported for experimental_predict_with_explanations.')
+
+    # pylint:disable=unused-argument
+    def new_model_fn(features, labels, mode):
+      return _bt_explanations_fn(features, self._head,
+                                 self._sorted_feature_columns)
+
+    # pylint:enable=unused-argument
+    est = estimator.Estimator(
+        model_fn=new_model_fn,
+        model_dir=self.model_dir,
+        config=self.config,
+        warm_start_from=self._warm_start_settings)
+    # Make sure bias and dfc will be in prediction dict.
+    user_supplied_predict_keys = predict_keys is not None
+    if user_supplied_predict_keys:
+      predict_keys = set(predict_keys)
+      predict_keys.add(boosted_trees_utils._DEBUG_PROTO_KEY)
+    predictions = est.predict(
+        input_fn,
+        predict_keys=predict_keys,
+        hooks=hooks,
+        checkpoint_path=checkpoint_path,
+        yield_single_examples=True)
+    for pred in predictions:
+      bias, dfcs = boosted_trees_utils._parse_explanations_from_prediction(
+          pred[boosted_trees_utils._DEBUG_PROTO_KEY], self._n_features,
+          self._is_classification)
+      pred['bias'] = bias
+      pred['dfc'] = dfcs
+      # Don't need to expose serialized proto to end user.
+      del pred[boosted_trees_utils._DEBUG_PROTO_KEY]
+      yield pred
+
+
+# pylint: disable=protected-access
 @estimator_export('estimator.BoostedTreesClassifier')
-class BoostedTreesClassifier(estimator.Estimator):
+class BoostedTreesClassifier(_BoostedTreesBase):
   """A Classifier for Tensorflow Boosted Trees models.
 
   @compatibility(eager)
@@ -1082,14 +1275,13 @@
       n_classes = 2
     head, closed_form = _create_classification_head_and_closed_form(
         n_classes, weight_column, label_vocabulary=label_vocabulary)
-
     # HParams for the model.
     tree_hparams = _TreeHParams(
         n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
         tree_complexity, min_node_weight, center_bias, pruning_mode)
 
     def _model_fn(features, labels, mode, config):
-      return _bt_model_fn(  # pylint: disable=protected-access
+      return _bt_model_fn(
           features,
           labels,
           mode,
@@ -1101,11 +1293,17 @@
           closed_form_grad_and_hess_fn=closed_form)
 
     super(BoostedTreesClassifier, self).__init__(
-        model_fn=_model_fn, model_dir=model_dir, config=config)
+        model_fn=_model_fn,
+        model_dir=model_dir,
+        config=config,
+        feature_columns=feature_columns,
+        head=head,
+        center_bias=center_bias,
+        is_classification=True)
 
 
 @estimator_export('estimator.BoostedTreesRegressor')
-class BoostedTreesRegressor(estimator.Estimator):
+class BoostedTreesRegressor(_BoostedTreesBase):
   """A Regressor for Tensorflow Boosted Trees models.
 
   @compatibility(eager)
@@ -1223,9 +1421,17 @@
         tree_complexity, min_node_weight, center_bias, pruning_mode)
 
     def _model_fn(features, labels, mode, config):
-      return _bt_model_fn(  # pylint: disable=protected-access
-          features, labels, mode, head, feature_columns, tree_hparams,
-          n_batches_per_layer, config)
+      return _bt_model_fn(features, labels, mode, head, feature_columns,
+                          tree_hparams, n_batches_per_layer, config)
 
     super(BoostedTreesRegressor, self).__init__(
-        model_fn=_model_fn, model_dir=model_dir, config=config)
+        model_fn=_model_fn,
+        model_dir=model_dir,
+        config=config,
+        feature_columns=feature_columns,
+        head=head,
+        center_bias=center_bias,
+        is_classification=False)
+
+
+# pylint: enable=protected-access
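
For context on how the new `experimental_predict_with_explanations` output is consumed, here is a minimal sketch. The prediction dicts below are hypothetical stand-ins for what the generator yields (each carries a `'bias'` scalar and a per-feature `'dfc'` dict); they are not produced by a trained model, and the numbers are only illustrative.

```python
# Hypothetical stand-ins for dicts yielded by
# experimental_predict_with_explanations (values are illustrative only).
predictions = [
    {'probabilities': 0.62, 'bias': 0.4,
     'dfc': {0: 0.1965, 1: 0.0, 2: 0.0269}},
    {'probabilities': 0.24, 'bias': 0.4,
     'dfc': {0: -0.1211, 1: 0.0, 2: -0.0393}},
]

for i, pred in enumerate(predictions):
  # The property the tests check: bias + sum(dfc) reconstructs the output
  # (a probability here, since this is the classifier case).
  reconstructed = pred['bias'] + sum(pred['dfc'].values())
  # Rank features by absolute contribution for a quick per-example report.
  top_feature = max(pred['dfc'], key=lambda f: abs(pred['dfc'][f]))
  print('example %d: prob=%.2f reconstructed=%.4f top_feature=%d'
        % (i, pred['probabilities'], reconstructed, top_feature))
```
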
diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py
index 6e28c72..d4cb3e2 100644
--- a/tensorflow/python/estimator/canned/boosted_trees_test.py
+++ b/tensorflow/python/estimator/canned/boosted_trees_test.py
@@ -564,6 +564,175 @@
     self.assertEqual(1, ensemble.trees[0].nodes[0].bucketized_split.feature_id)
     self.assertEqual(0, ensemble.trees[0].nodes[0].bucketized_split.threshold)
 
+  def testTreeComplexityIsSetCorrectly(self):
+    input_fn = _make_train_input_fn(is_classification=True)
+
+    num_steps = 10
+    # Tree complexity is set but no pruning.
+    est = boosted_trees.BoostedTreesClassifier(
+        feature_columns=self._feature_columns,
+        n_batches_per_layer=1,
+        n_trees=1,
+        max_depth=5,
+        tree_complexity=1e-3)
+    with self.assertRaisesRegexp(ValueError, 'Tree complexity have no effect'):
+      est.train(input_fn, steps=num_steps)
+
+    # Pruning but no tree complexity.
+    est = boosted_trees.BoostedTreesClassifier(
+        feature_columns=self._feature_columns,
+        n_batches_per_layer=1,
+        n_trees=1,
+        max_depth=5,
+        pruning_mode='pre')
+    with self.assertRaisesRegexp(ValueError,
+                                 'tree_complexity must be positive'):
+      est.train(input_fn, steps=num_steps)
+
+    # All is good.
+    est = boosted_trees.BoostedTreesClassifier(
+        feature_columns=self._feature_columns,
+        n_batches_per_layer=1,
+        n_trees=1,
+        max_depth=5,
+        pruning_mode='pre',
+        tree_complexity=1e-3)
+    est.train(input_fn, steps=num_steps)
+
+
+class BoostedTreesDebugOutputsTest(test_util.TensorFlowTestCase):
+  """Test debug/model explainability outputs for individual predictions.
+
+  Includes directional feature contributions (DFC).
+  """
+
+  def setUp(self):
+    self._feature_columns = {
+        feature_column.bucketized_column(
+            feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32),
+            BUCKET_BOUNDARIES) for i in range(NUM_FEATURES)
+    }
+
+  def testBinaryClassifierThatDFCIsInPredictions(self):
+    train_input_fn = _make_train_input_fn(is_classification=True)
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x=FEATURES_DICT, y=None, batch_size=3, num_epochs=1, shuffle=False)
+
+    est = boosted_trees.BoostedTreesClassifier(
+        feature_columns=self._feature_columns,
+        n_batches_per_layer=1,
+        n_trees=1,
+        max_depth=5,
+        center_bias=True)
+
+    num_steps = 100
+    # Train for a few steps. Validate debug outputs in prediction dicts.
+    est.train(train_input_fn, steps=num_steps)
+    debug_predictions = est.experimental_predict_with_explanations(
+        predict_input_fn)
+    biases, dfcs = zip(*[(pred['bias'], pred['dfc'])
+                         for pred in debug_predictions])
+    self.assertAllClose([0.4] * 5, biases)
+    self.assertAllClose(({
+        0: -0.12108613453574479,
+        1: 0.0,
+        2: -0.039254929814481143
+    }, {
+        0: 0.19650601422250574,
+        1: 0.0,
+        2: 0.02693827052766018
+    }, {
+        0: 0.16057487356133376,
+        1: 0.0,
+        2: 0.02693827052766018
+    }, {
+        0: -0.12108613453574479,
+        1: 0.0,
+        2: -0.039254929814481143
+    }, {
+        0: -0.10832468554550384,
+        1: 0.0,
+        2: 0.02693827052766018
+    }), dfcs)
+
+    # Assert sum(dfcs) + bias == probabilities.
+    expected_probabilities = [
+        0.23965894, 0.62344426, 0.58751315, 0.23965894, 0.31861359
+    ]
+    probabilities = [
+        sum(dfc.values()) + bias for (dfc, bias) in zip(dfcs, biases)
+    ]
+    self.assertAllClose(expected_probabilities, probabilities)
+
+    # When user doesn't include bias or dfc in predict_keys, make sure to still
+    # include dfc and bias.
+    debug_predictions = est.experimental_predict_with_explanations(
+        predict_input_fn, predict_keys=['probabilities'])
+    for prediction_dict in debug_predictions:
+      self.assertTrue('bias' in prediction_dict)
+      self.assertTrue('dfc' in prediction_dict)
+      self.assertTrue('probabilities' in prediction_dict)
+      self.assertEqual(len(prediction_dict), 3)
+
+  def testRegressorThatDFCIsInPredictions(self):
+    train_input_fn = _make_train_input_fn(is_classification=False)
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+    est = boosted_trees.BoostedTreesRegressor(
+        feature_columns=self._feature_columns,
+        n_batches_per_layer=1,
+        n_trees=1,
+        max_depth=5,
+        center_bias=True)
+
+    num_steps = 100
+    # Train for a few steps. Validate debug outputs in prediction dicts.
+    est.train(train_input_fn, steps=num_steps)
+    debug_predictions = est.experimental_predict_with_explanations(
+        predict_input_fn)
+    biases, dfcs = zip(*[(pred['bias'], pred['dfc'])
+                         for pred in debug_predictions])
+    self.assertAllClose([1.8] * 5, biases)
+    self.assertAllClose(({
+        0: -0.070499420166015625,
+        1: -0.095000028610229492,
+        2: 0.0
+    }, {
+        0: -0.53763031959533691,
+        1: 0.063333392143249512,
+        2: 0.0
+    }, {
+        0: -0.51756942272186279,
+        1: -0.095000028610229492,
+        2: 0.0
+    }, {
+        0: 0.1563495397567749,
+        1: 0.063333392143249512,
+        2: 0.0
+    }, {
+        0: 0.96934974193572998,
+        1: 0.063333392143249512,
+        2: 0.0
+    }), dfcs)
+
+    # Assert sum(dfcs) + bias == predictions.
+    expected_predictions = [[1.6345005], [1.32570302], [1.1874305],
+                            [2.01968288], [2.83268309]]
+    predictions = [
+        [sum(dfc.values()) + bias] for (dfc, bias) in zip(dfcs, biases)
+    ]
+    self.assertAllClose(expected_predictions, predictions)
+
+    # Test when user doesn't include bias or dfc in predict_keys.
+    debug_predictions = est.experimental_predict_with_explanations(
+        predict_input_fn, predict_keys=['predictions'])
+    for prediction_dict in debug_predictions:
+      self.assertTrue('bias' in prediction_dict)
+      self.assertTrue('dfc' in prediction_dict)
+      self.assertTrue('predictions' in prediction_dict)
+      self.assertEqual(len(prediction_dict), 3)
+
 
 class ModelFnTests(test_util.TensorFlowTestCase):
   """Tests bt_model_fn including unexposed internal functionalities."""
diff --git a/tensorflow/python/estimator/canned/boosted_trees_utils.py b/tensorflow/python/estimator/canned/boosted_trees_utils.py
new file mode 100644
index 0000000..85efc23
--- /dev/null
+++ b/tensorflow/python/estimator/canned/boosted_trees_utils.py
@@ -0,0 +1,80 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Debug and model explainability logic for boosted trees."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2
+
+# For directional feature contributions.
+_DEBUG_PROTO_KEY = '_serialized_debug_outputs_proto'
+_BIAS_ID = 0
+
+
+def _parse_debug_proto_string(example_proto_serialized):
+  example_debug_outputs = boosted_trees_pb2.DebugOutput()
+  example_debug_outputs.ParseFromString(example_proto_serialized)
+  feature_ids = example_debug_outputs.feature_ids
+  logits_path = example_debug_outputs.logits_path
+  return feature_ids, logits_path
+
+
+def _compute_directional_feature_contributions(example_feature_ids,
+                                               example_logits_paths, activation,
+                                               num_bucketized_features):
+  """Directional feature contributions and bias, per example."""
+  # Initialize contributions to 0.
+  dfcs = {k: 0 for k in range(num_bucketized_features)}
+
+  # Traverse tree subtracting child prediction from parent prediction and
+  # associating change with feature id used to split.
+  predictions = np.array(activation(example_logits_paths))
+  delta_pred = predictions[_BIAS_ID + 1:] - predictions[:-1]
+  # Group by feature id, then sum delta_pred.
+  contribs = np.bincount(
+      example_feature_ids,
+      weights=delta_pred,
+      minlength=num_bucketized_features)
+  for f, dfc in zip(range(num_bucketized_features), contribs):
+    dfcs[f] = dfc
+  return predictions[_BIAS_ID], dfcs
+
+
+def _identity(logits):
+  return logits
+
+
+def _sigmoid(logits):
+  # TODO(crawles): Change to softmax once multiclass support is available.
+  return 1 / (1 + np.exp(-np.array(logits)))
+
+
+def _parse_explanations_from_prediction(serialized_debug_proto,
+                                        n_features,
+                                        classification=False):
+  """Parse serialized explanability proto, compute dfc, and return bias, dfc."""
+  feature_ids, logits_path = _parse_debug_proto_string(serialized_debug_proto)
+  if classification:
+    activation = _sigmoid
+  else:
+    activation = _identity
+  bias, dfcs = _compute_directional_feature_contributions(
+      feature_ids, logits_path, activation, n_features)
+  # TODO(crawles): Prediction path and leaf IDs.
+  return bias, dfcs
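
To make the attribution rule above concrete, here is a standalone NumPy sketch that mirrors `_compute_directional_feature_contributions`; the helper name `dfc_from_path` is ours, and the inputs are taken from example 1 of the unit test that follows (identity activation, no bias centering).

```python
import numpy as np


def dfc_from_path(feature_ids, logits_path, activation, num_features):
  """Mirrors _compute_directional_feature_contributions (illustration only)."""
  preds = np.array([activation(logit) for logit in logits_path])
  # Change in prediction introduced by each split along the path.
  delta_pred = preds[1:] - preds[:-1]
  # Sum the deltas per feature id; features not used in the path stay at 0.
  contribs = np.bincount(
      feature_ids, weights=delta_pred, minlength=num_features)
  return preds[0], {f: contribs[f] for f in range(num_features)}


# Example 1 from the unit test below: identity activation, no bias centering.
bias, dfcs = dfc_from_path(
    feature_ids=(2, 2, 0, 0),
    logits_path=(0.0, 0.114, 1.214, 1.114, 6.114),
    activation=lambda x: x,
    num_features=3)
assert abs(bias) < 1e-9
assert abs(dfcs[0] - 4.9) < 1e-6 and abs(dfcs[2] - 1.214) < 1e-6
# The bias plus all contributions recovers the final leaf prediction.
assert abs(bias + sum(dfcs.values()) - 6.114) < 1e-6
```
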
diff --git a/tensorflow/python/estimator/canned/boosted_trees_utils_test.py b/tensorflow/python/estimator/canned/boosted_trees_utils_test.py
new file mode 100644
index 0000000..506d4ea
--- /dev/null
+++ b/tensorflow/python/estimator/canned/boosted_trees_utils_test.py
@@ -0,0 +1,187 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests boosted_trees estimators and model_fn."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.estimator.canned import boosted_trees_utils
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class BoostedTreesDFCTest(test_util.TensorFlowTestCase):
+  """Test directional feature contributions (DFC) helper functions. """
+
+  def testDirectionalFeatureContributionsCompute(self):
+    """Tests logic to compute DFCs given feature ids and logits paths."""
+    num_bucketized_features = 3  # Includes one unused feature.
+    examples_feature_ids = ((2, 2, 0, 0), (2, 2, 0))
+    e1_feature_ids, e2_feature_ids = examples_feature_ids
+
+    # DFCs are computed by traversing the prediction path and subtracting each
+    # child prediction from its parent prediction and associating the change in
+    # prediction with the respective feature id used for the split.
+    # For each activation function f (currently identity or sigmoid), DFCs are
+    # calculated for the two examples as:
+    # example 1:
+    #   feature_0 = (f(1.114) - f(1.214)) + (f(6.114) - f(1.114))
+    #   feature_1 = 0  # Feature not in ensemble, thus zero contrib.
+    #   feature_2 = (f(0.114) - bias_pred) + (f(1.214) - f(0.114))
+    # example 2:
+    #   feature_0 = f(-5.486) - f(1.514)
+    #   feature_1 = 0  # Feature not in ensemble, thus zero contrib.
+    #   feature_2 = (f(0.114) - bias_pred) + (f(1.514) - f(0.114))
+    # where bias_pred is f(0) or f(0.21), with center_bias = {False, True},
+    # respectively.
+    # Keys are center_bias.
+    expected_dfcs_identity = {
+        False: ({
+            0: 4.9,
+            1: 0,
+            2: 1.214
+        }, {
+            0: -7.0,
+            1: 0,
+            2: 1.514
+        }),
+        True: ({
+            0: 4.9,
+            1: 0,
+            2: 1.0039999999999998
+        }, {
+            0: -7.0,
+            1: 0,
+            2: 1.3039999999999998
+        })
+    }
+    expected_dfcs_sigmoid = {
+        False: ({
+            0: 0.22678725678805578,
+            1: 0,
+            2: 0.2710059376234506
+        }, {
+            0: -0.81552596670046507,
+            1: 0,
+            2: 0.319653250251275
+        }),
+        True: ({
+            0: 0.22678725678805578,
+            1: 0,
+            2: 0.2186980280491253
+        }, {
+            0: -0.81552596670046507,
+            1: 0,
+            2: 0.26734534067694971
+        })
+    }
+    # pylint: disable=protected-access
+    for f, expected_dfcs in zip(
+        (boosted_trees_utils._identity, boosted_trees_utils._sigmoid),
+        (expected_dfcs_identity, expected_dfcs_sigmoid)):
+      for center_bias in [False, True]:
+        # If not center_bias, the bias after activation is 0.
+        if center_bias:
+          bias_logit = 0.21  # Root node of tree_0.
+        else:
+          bias_logit = 0  # 0 is default value when there is no original_leaf.
+        f_bias = f(bias_logit)
+
+        # Logits before and after each split, as output from
+        # boosted_trees_ops.example_debug_outputs.
+        examples_logits_paths = ((bias_logit, 0.114, 1.214, 1.114, 6.114),
+                                 (bias_logit, 0.114, 1.514, -5.486))
+        e1_logits_path, e2_logits_path = examples_logits_paths
+        e1_expected_dfcs, e2_expected_dfcs = expected_dfcs[center_bias]
+        # Check feature contributions are correct for both examples.
+        # Example 1.
+        # pylint:disable=line-too-long
+        e1_bias, e1_dfc = boosted_trees_utils._compute_directional_feature_contributions(
+            e1_feature_ids, e1_logits_path, f, num_bucketized_features)
+        self.assertAllClose(e1_bias, f_bias)
+        self.assertAllClose(e1_dfc, e1_expected_dfcs)
+        # Example 2.
+        e2_bias, e2_dfc = boosted_trees_utils._compute_directional_feature_contributions(
+            e2_feature_ids, e2_logits_path, f, num_bucketized_features)
+        # pylint:enable=line-too-long
+        self.assertAllClose(e2_bias, f_bias)
+        self.assertAllClose(e2_dfc, e2_expected_dfcs)
+        # Check if contributions sum to final prediction.
+        # For each tree, get leaf of last tree.
+        expected_logits = (e1_logits_path[-1], e2_logits_path[-1])
+        # Predictions should be the sum of contributions + bias.
+        expected_preds = [f(logit) for logit in expected_logits]
+        e1_pred = e1_bias + sum(e1_dfc.values())
+        e2_pred = e2_bias + sum(e2_dfc.values())
+        preds = [e1_pred, e2_pred]
+        self.assertAllClose(preds, expected_preds)
+    # pylint: enable=protected-access
+
+  def testDFCComputeComparedToExternalExample(self):
+    """Tests `compute_dfc` compared to external example (regression).
+
+    Example from http://blog.datadive.net/interpreting-random-forests.
+    """
+    # DIS:3, RM: 2, LSTAT:1, NOX:0
+    num_bucketized_features = 4
+    e1_feature_ids = (2, 1, 0)
+    e2_feature_ids = (2, 2, 2)
+    e3_feature_ids = (2, 2, 0)
+
+    bias_logit = 22.60  # Root node of tree_0.
+    activation = boosted_trees_utils._identity
+    f_bias = activation(bias_logit)
+    # Logits before and after each split, as output from
+    # boosted_trees_ops.example_debug_outputs.
+    e1_logits_path = (bias_logit, 19.96, 14.91, 18.11)
+    e2_logits_path = (bias_logit, 37.42, 45.10, 45.90)
+    e3_logits_path = (bias_logit, 37.42, 32.30, 33.58)
+    e1_expected_dfcs = {0: 3.20, 1: -5.05, 2: -2.64, 3: 0}
+    e2_expected_dfcs = {0: 0, 1: 0, 2: 23.3, 3: 0}
+    e3_expected_dfcs = {0: 1.28, 1: 0, 2: 9.7, 3: 0}
+    # Check feature contributions are correct for both examples.
+    # Example 1.
+    # pylint: disable=protected-access
+    # pylint: disable=line-too-long
+    e1_bias, e1_dfc = boosted_trees_utils._compute_directional_feature_contributions(
+        e1_feature_ids, e1_logits_path, activation, num_bucketized_features)
+    self.assertAllClose(e1_bias, f_bias)
+    self.assertAllClose(e1_dfc, e1_expected_dfcs)
+    # Example 2.
+    e2_bias, e2_dfc = boosted_trees_utils._compute_directional_feature_contributions(
+        e2_feature_ids, e2_logits_path, activation, num_bucketized_features)
+    self.assertAllClose(e2_bias, f_bias)
+    self.assertAllClose(e2_dfc, e2_expected_dfcs)
+    # Example 3.
+    e3_bias, e3_dfc = boosted_trees_utils._compute_directional_feature_contributions(
+        e3_feature_ids, e3_logits_path, activation, num_bucketized_features)
+    # pylint: enable=line-too-long
+    self.assertAllClose(e3_bias, f_bias)
+    self.assertAllClose(e3_dfc, e3_expected_dfcs)
+    # pylint: enable=protected-access
+    # Check if contributions sum to final prediction.
+    # For each tree, get leaf of last tree.
+    expected_logits = (18.11, 45.90, 33.58)
+    # Predictions should be the sum of contributions + bias.
+    expected_preds = [activation(logit) for logit in expected_logits]
+    e1_pred = e1_bias + sum(e1_dfc.values())
+    e2_pred = e2_bias + sum(e2_dfc.values())
+    e3_pred = e3_bias + sum(e3_dfc.values())
+    preds = [e1_pred, e2_pred, e3_pred]
+    self.assertAllClose(preds, expected_preds)
+
+
+if __name__ == '__main__':
+  googletest.main()
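
As a quick arithmetic check on `testDFCComputeComparedToExternalExample`, the bias plus the expected per-feature contributions should reproduce each example's final leaf value, to within the rounding of the published figures:

```python
bias = 22.60
cases = [
    # (final leaf value, expected DFCs from the test above)
    (18.11, {0: 3.20, 1: -5.05, 2: -2.64, 3: 0}),
    (45.90, {0: 0, 1: 0, 2: 23.3, 3: 0}),
    (33.58, {0: 1.28, 1: 0, 2: 9.7, 3: 0}),
]
for expected, dfcs in cases:
  reconstructed = bias + sum(dfcs.values())
  # The published numbers are rounded to two decimals, hence the tolerance.
  assert abs(reconstructed - expected) < 0.05, (reconstructed, expected)
```
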
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 0f20ace..2dc5d09 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -41,7 +41,6 @@
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed
 from tensorflow.python.framework import tensor_util
-from tensorflow.python.keras import metrics
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import metrics as metrics_lib
@@ -329,7 +328,7 @@
                                  run_config.TaskType.PS):
       raise ValueError(
           'Train has been called wrong configuration. Please use '
-          'tf.estimator.train_and_evaluate which calls propper API according '
+          'tf.estimator.train_and_evaluate which calls proper API according '
           'to given configuration. Current configuration: {}.'.format(
               self.config))
 
@@ -490,6 +489,10 @@
               yield_single_examples=True):
     """Yields predictions for given features.
 
+    Please note that interleaving two predict outputs does not work. See:
+    [issue/20506](
+    https://github.com/tensorflow/tensorflow/issues/20506#issuecomment-422208517)
+
     Args:
       input_fn: A function that constructs the features. Prediction continues
         until `input_fn` raises an end-of-input exception
@@ -1653,7 +1656,7 @@
   def _unwrap_and_concat(value):
     value = nest.flatten(distribution.unwrap(value))
     if len(value) != 1:
-      return array_ops.concat(value)
+      return array_ops.concat(value, 0)
     return value[0]
 
   ready_op = distribution.call_for_each_tower(
@@ -1788,18 +1791,9 @@
   value_ops = {}
   # Sort metrics lexicographically so graph is identical every time.
   for name, value in sorted(six.iteritems(eval_dict)):
-    if isinstance(value, metrics.Metric):
-      metric_result = value.result()
-      # We expect only one update op for every metric when there is no
-      # distribution strategy.
-      metric_update = value.updates if distribution else value.updates[0]
-    else:
-      metric_result = value[0]
-      metric_update = value[1]
-
-    value_ops[name] = metric_result
+    value_ops[name] = value[0]
     update_ops.append(
-        distribution.group(metric_update) if distribution else metric_update)
+        distribution.group(value[1]) if distribution else value[1])
 
   update_op = control_flow_ops.group(*update_ops) if update_ops else None
   return update_op, value_ops
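
The simplification of the metric-update extraction above assumes every `eval_dict` entry is already a `(value_tensor, update_op)` pair; `Metric` objects are converted earlier (see the `model_fn.py` hunk below). A framework-free sketch of that contract, with strings standing in for tensors and ops:

```python
# Strings stand in for tensors/ops, purely to show the shape of the data the
# simplified loop expects once the Metric special case is gone.
eval_dict = {
    'accuracy': ('acc_value_tensor', 'acc_update_op'),
    'mean_loss': ('loss_value_tensor', 'loss_update_op'),
}

value_ops = {}
update_ops = []
# Sort metrics lexicographically so graph construction order is deterministic.
for name, value in sorted(eval_dict.items()):
  value_ops[name] = value[0]
  update_ops.append(value[1])

assert value_ops == {'accuracy': 'acc_value_tensor',
                     'mean_loss': 'loss_value_tensor'}
assert update_ops == ['acc_update_op', 'loss_update_op']
```
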
diff --git a/tensorflow/python/estimator/export/export_test.py b/tensorflow/python/estimator/export/export_test.py
index 3eed1ab..ed3219c 100644
--- a/tensorflow/python/estimator/export/export_test.py
+++ b/tensorflow/python/estimator/export/export_test.py
@@ -376,7 +376,7 @@
                         "  } "
                         "} ", example)
 
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         sparse_result = sess.run(
             serving_input_receiver.features,
             feed_dict={
diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py
index 439cc2e..8247894 100644
--- a/tensorflow/python/estimator/model_fn.py
+++ b/tensorflow/python/estimator/model_fn.py
@@ -308,6 +308,8 @@
     for key, value in six.iteritems(eval_metric_ops):
       if isinstance(value, Metric):
         vars_to_add.update(value.variables)
+        # Convert Metric instances to (value_tensor, update_op) tuple.
+        eval_metric_ops[key] = (value.result(), value.updates[0])
     # Remove variables that are in the local variables collection already.
     vars_to_add = vars_to_add.difference(local_vars)
     for v in vars_to_add:
@@ -466,13 +468,13 @@
 
 
 def _check_is_tensor_or_operation(x, name):
-  if not (isinstance(x, ops.Operation) or isinstance(x, ops.Tensor)):
+  if not (isinstance(x, ops.Operation) or ops.is_dense_tensor_like(x)):
     raise TypeError('{} must be Operation or Tensor, given: {}'.format(name, x))
 
 
 def _check_is_tensor(x, tensor_name):
   """Returns `x` if it is a `Tensor`, raises TypeError otherwise."""
-  if not isinstance(x, ops.Tensor):
+  if not ops.is_dense_tensor_like(x):
     raise TypeError('{} must be Tensor, given: {}'.format(tensor_name, x))
   return x
 
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index a8aef3a..68b3170 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -762,13 +762,12 @@
         if handle_data:
           handle_data = handle_data.SerializeToString()
       else:
-        handle_data = c_api.GetResourceHandleShapeAndType(
-            tensor.graph._c_graph, tensor._as_tf_output())
+        handle_data = c_api.GetHandleShapeAndType(tensor.graph._c_graph,
+                                                  tensor._as_tf_output())
 
       if handle_data:
-        c_api.SetResourceHandleShapeAndType(ph.graph._c_graph,
-                                            ph._as_tf_output(),
-                                            compat.as_bytes(handle_data))
+        c_api.SetHandleShapeAndType(ph.graph._c_graph, ph._as_tf_output(),
+                                    compat.as_bytes(handle_data))
     else:
       ph._handle_data = tensor._handle_data
     # pylint: enable=protected-access
diff --git a/tensorflow/python/framework/load_library.py b/tensorflow/python/framework/load_library.py
index 535c601..908a5f5 100644
--- a/tensorflow/python/framework/load_library.py
+++ b/tensorflow/python/framework/load_library.py
@@ -18,14 +18,18 @@
 from __future__ import division
 from __future__ import print_function
 
+import errno
 import hashlib
 import imp
+import os
+import platform
 import sys
 import threading  # pylint: disable=unused-import
 
 from tensorflow.core.framework import op_def_pb2
 from tensorflow.core.lib.core import error_codes_pb2  # pylint: disable=unused-import
 from tensorflow.python import pywrap_tensorflow as py_tf
+from tensorflow.python.lib.io import file_io
 from tensorflow.python.util import compat
 from tensorflow.python.util.tf_export import tf_export
 
@@ -98,3 +102,64 @@
     RuntimeError: when unable to load the library.
   """
   py_tf.TF_LoadLibrary(library_filename)
+
+
+def _is_shared_object(filename):
+  """Check the file to see if it is a shared object, only using extension."""
+  if platform.system() == 'Linux':
+    if filename.endswith('.so'):
+      return True
+    else:
+      index = filename.rfind('.so.')
+      if index == -1:
+        return False
+      else:
+        # A shared object with the API version in filename
+        return filename[index + 4].isdecimal()
+  elif platform.system() == 'Darwin':
+    return filename.endswith('.dylib')
+  elif platform.system() == 'Windows':
+    return filename.endswith('.dll')
+  else:
+    return False
+
+
+@tf_export('load_library')
+def load_library(library_location):
+  """Loads a TensorFlow plugin.
+
+  "library_location" can be a path to a specific shared object, or a folder.
+  If it is a folder, all shared objects that are named "libtfkernel*" will be
+  loaded. When the library is loaded, kernels registered in the library via the
+  `REGISTER_*` macros are made available in the TensorFlow process.
+
+  Args:
+    library_location: Path to the plugin or the folder of plugins.
+      Relative or absolute filesystem path to a dynamic library file or folder.
+
+  Returns:
+    None
+
+  Raises:
+    OSError: When the file to be loaded is not found.
+    RuntimeError: when unable to load the library.
+  """
+  if file_io.file_exists(library_location):
+    if file_io.is_directory(library_location):
+      directory_contents = file_io.list_directory(library_location)
+
+      kernel_libraries = [
+          os.path.join(library_location, f) for f in directory_contents
+          if _is_shared_object(f)]
+    else:
+      kernel_libraries = [library_location]
+
+    for lib in kernel_libraries:
+      py_tf.TF_LoadLibrary(lib)
+
+  else:
+    raise OSError(
+        errno.ENOENT,
+        'The file or folder to load kernel libraries from does not exist.',
+        library_location)
+
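
A standalone sketch of the extension rules used by `_is_shared_object`, with the platform name passed in so all three branches can be exercised without TensorFlow; the function name `is_shared_object` is ours.

```python
def is_shared_object(filename, system):
  """Same extension rules as _is_shared_object, with the platform passed in."""
  if system == 'Linux':
    if filename.endswith('.so'):
      return True
    index = filename.rfind('.so.')
    # Versioned libraries such as libfoo.so.2 also count.
    return index != -1 and filename[index + 4].isdecimal()
  elif system == 'Darwin':
    return filename.endswith('.dylib')
  elif system == 'Windows':
    return filename.endswith('.dll')
  return False


assert is_shared_object('libtfkernel_foo.so', 'Linux')
assert is_shared_object('libtfkernel_foo.so.2', 'Linux')
assert not is_shared_object('libtfkernel_foo.txt', 'Linux')
assert is_shared_object('libtfkernel_foo.dylib', 'Darwin')
assert is_shared_object('tfkernel_foo.dll', 'Windows')
```
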
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 343f52f..8bb1779 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -2532,8 +2532,8 @@
     output._shape_val = output._c_api_shape()
     # Set the resource handle data for compatibility with the Python shape
     # inference code.
-    serialized = c_api.GetResourceHandleShapeAndType(op._graph._c_graph,
-                                                     output._as_tf_output())
+    serialized = c_api.GetHandleShapeAndType(op._graph._c_graph,  # pylint: disable=protected-access
+                                             output._as_tf_output())
     if serialized:
       output._handle_data = (
           cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index d59adf3..c3a3437 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -2142,8 +2142,8 @@
 
     def function_with_variables():
       with ops.init_scope():
-        v = resource_variable_ops.ResourceVariable(3)
-      return v.assign_add(1)
+        self.v = resource_variable_ops.ResourceVariable(3)
+      return self.v.assign_add(1)
 
     with context.eager_mode():
       # Each invocation of function_with_variables recreates a variable.
@@ -2188,13 +2188,13 @@
 
     def inner_function():
       with ops.init_scope():
-        v = resource_variable_ops.ResourceVariable(1)
-      return v.assign_add(2)
+        self.v = resource_variable_ops.ResourceVariable(1)
+      return self.v.assign_add(2)
 
     def outer_function(inner=None):
       with ops.init_scope():
-        v0 = resource_variable_ops.ResourceVariable(0)
-      return v0.assign_add(1) + inner()
+        self.v0 = resource_variable_ops.ResourceVariable(0)
+      return self.v0.assign_add(1) + inner()
 
     with context.eager_mode():
       # Each invocation of outer_function recreates variables.
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index b739823..c302072 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -24,6 +24,7 @@
 import contextlib
 import gc
 import itertools
 import math
+import os
 import random
 import re
@@ -868,6 +869,19 @@
     yield
 
 
+class CapturedWrites(object):
+  """A utility class to load the captured writes made to a stream."""
+
+  def __init__(self, capture_location):
+    self.capture_location = capture_location
+
+  def contents(self):
+    """Get the captured writes as a single string."""
+    with open(self.capture_location) as tmp_file:
+      output_data = "".join(tmp_file.readlines())
+    return output_data
+
+
 class ErrorLoggingSession(session.Session):
   """Wrapper around a Session that logs errors in run().
   """
@@ -934,6 +948,52 @@
       self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
     return self._tempdir
 
+  @contextlib.contextmanager
+  def captureWritesToStream(self, stream):
+    """A context manager that captures the writes to a given stream.
+
+    This context manager captures all writes to a given stream inside of a
+    `CapturedWrites` object. When the context manager is entered, it yields
+    the `CapturedWrites` object. The captured contents can be accessed by
+    calling `.contents()` on the `CapturedWrites`.
+
+    For this function to work, the stream must have a file descriptor that
+    can be modified using `os.dup` and `os.dup2`, and the stream must support
+    a `.flush()` method. The default Python sys.stdout and sys.stderr are
+    examples of this. Note that this does not work in Colab or Jupyter
+    notebooks, because those use alternate stdout streams.
+
+    Example:
+    ```python
+    class MyOperatorTest(test_util.TensorFlowTestCase):
+      def testMyOperator(self):
+        input = [1.0, 2.0, 3.0, 4.0, 5.0]
+        with self.captureWritesToStream(sys.stdout) as captured:
+          result = MyOperator(input).eval()
+        self.assertStartsWith(captured.contents(), "This was printed.")
+    ```
+
+    Args:
+      stream: The stream whose writes should be captured. This
+        stream must have a file descriptor, support writing via that
+        file descriptor, and must have a `.flush()` method.
+
+    Yields:
+      A `CapturedWrites` object that contains all writes to the specified stream
+      made during this context.
+    """
+    stream.flush()
+    fd = stream.fileno()
+    tmp_file_path = tempfile.mktemp(dir=self.get_temp_dir())
+    tmp_file = open(tmp_file_path, "w")
+    orig_fd = os.dup(fd)
+    os.dup2(tmp_file.fileno(), fd)
+    try:
+      yield CapturedWrites(tmp_file_path)
+    finally:
+      tmp_file.close()
+      os.dup2(orig_fd, fd)
+
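
The `captureWritesToStream` helper above is built on the classic `os.dup`/`os.dup2` redirection trick. A minimal standalone sketch of that technique, independent of the test framework:

```python
import contextlib
import os
import sys
import tempfile


@contextlib.contextmanager
def capture_fd_writes(stream):
  """Redirects a stream's file descriptor into a temp file; yields its path."""
  stream.flush()
  fd = stream.fileno()
  tmp = tempfile.NamedTemporaryFile(mode='w', delete=False)
  saved_fd = os.dup(fd)          # Remember where the stream used to point.
  os.dup2(tmp.fileno(), fd)      # Point it at the temp file instead.
  try:
    yield tmp.name
  finally:
    stream.flush()               # Push buffered writes through the dup'd fd.
    tmp.close()
    os.dup2(saved_fd, fd)        # Restore the original destination.
    os.close(saved_fd)


with capture_fd_writes(sys.stdout) as path:
  print('captured line')
with open(path) as f:
  assert 'captured line' in f.read()
```
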
   def _AssertProtoEquals(self, a, b, msg=None):
     """Asserts that a and b are the same proto.
 
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index b521b14..4a72c4b 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -381,12 +381,11 @@
     ],
 )
 
-py_test(
+cuda_py_test(
     name = "embeddings_test",
     size = "medium",
     srcs = ["layers/embeddings_test.py"],
-    srcs_version = "PY2AND3",
-    deps = [
+    additional_deps = [
         ":keras",
         "//tensorflow/python:client_testlib",
     ],
diff --git a/tensorflow/python/keras/applications/__init__.py b/tensorflow/python/keras/applications/__init__.py
index a8b6d55..c35cdb1 100644
--- a/tensorflow/python/keras/applications/__init__.py
+++ b/tensorflow/python/keras/applications/__init__.py
@@ -63,7 +63,8 @@
   def wrapper(*args, **kwargs):
     if hasattr(keras_applications, 'get_submodules_from_kwargs'):
       kwargs['backend'] = backend
-      kwargs['layers'] = layers
+      if 'layers' not in kwargs:
+        kwargs['layers'] = layers
       kwargs['models'] = models
       kwargs['utils'] = utils
     return base_fun(*args, **kwargs)
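
The guard above stops the wrapper from clobbering a caller-supplied `layers` module. A generic sketch of the pattern (names here are illustrative, not the real keras_applications machinery): inject a default keyword only when the caller did not provide one.

```python
import functools

DEFAULT_LAYERS = 'tf.keras.layers'   # Stand-in for the real module object.


def inject_submodules(base_fun):
  """Fills in 'layers' only when the caller did not pass one explicitly."""
  @functools.wraps(base_fun)
  def wrapper(*args, **kwargs):
    if 'layers' not in kwargs:
      kwargs['layers'] = DEFAULT_LAYERS
    return base_fun(*args, **kwargs)
  return wrapper


@inject_submodules
def build_model(layers=None):
  return layers


assert build_model() == DEFAULT_LAYERS
assert build_model(layers='my_custom_layers') == 'my_custom_layers'
```
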
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 529b07d..60ed8e8c 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -696,14 +696,14 @@
     return
   graph = v.graph if hasattr(v, 'graph') else ops.get_default_graph()
   if graph not in _GRAPH_VARIABLES:
-    _GRAPH_VARIABLES[graph] = set()
+    _GRAPH_VARIABLES[graph] = weakref.WeakSet()
   _GRAPH_VARIABLES[graph].add(v)
 
 
 def _get_variables(graph=None):
   """Returns variables corresponding to the given graph for initialization."""
   assert not context.executing_eagerly()
-  variables = _GRAPH_VARIABLES.get(graph, set())
+  variables = _GRAPH_VARIABLES.setdefault(graph, weakref.WeakSet())
   for opt in _GRAPH_TF_OPTIMIZERS.get(graph, set()):
     variables.update(opt.optimizer.variables())
   return variables
@@ -3459,14 +3459,18 @@
   Returns:
       A tensor.
   """
-  clip_max = max_value is not None
 
   if alpha != 0.:
+    if max_value is None and threshold == 0:
+      return nn.leaky_relu(x, alpha=alpha)
+
     if threshold != 0:
       negative_part = nn.relu(-x + threshold)
     else:
       negative_part = nn.relu(-x)
 
+  clip_max = max_value is not None
+
   if threshold != 0:
     # computes x for x > threshold else 0
     x = x * math_ops.cast(math_ops.greater(x, threshold), floatx())
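
The reordering above lets `relu` hand off to `nn.leaky_relu` whenever only `alpha` is set (`max_value is None` and `threshold == 0`). Below is a NumPy reference of our reading of the intended semantics, useful for seeing why that precondition makes the fast path safe:

```python
import numpy as np


def relu_reference(x, alpha=0., max_value=None, threshold=0.):
  """Reference semantics for the Keras backend relu, in NumPy (our reading)."""
  x = np.asarray(x, dtype=float)
  # Positive part (shifted by threshold when one is given).
  pos = np.where(x > threshold, x, 0.)
  if max_value is not None:
    pos = np.minimum(pos, max_value)
  if alpha != 0.:
    # The leaky slope applies below the threshold.
    neg = alpha * np.minimum(x - threshold, 0.)
    return pos + neg
  return pos


x = np.array([[-10., -2.], [2., 7.]])
# With only alpha set, the result matches a plain leaky ReLU, which is exactly
# the case the new fast path hands off to nn.leaky_relu.
leaky = np.where(x > 0, x, 0.5 * x)
assert np.allclose(relu_reference(x, alpha=0.5), leaky)
# With a max_value, positive activations are clipped.
assert np.allclose(relu_reference(x, max_value=5.), [[0., 0.], [2., 5.]])
```
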
diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py
index 2f271c4..ab71589 100644
--- a/tensorflow/python/keras/backend_test.py
+++ b/tensorflow/python/keras/backend_test.py
@@ -522,8 +522,9 @@
       relu_op = keras.backend.relu(x)
       self.assertAllClose(keras.backend.eval(relu_op), [[0, 0], [2, 7]])
 
-      # alpha
+      # alpha (leaky relu used)
       relu_op = keras.backend.relu(x, alpha=0.5)
+      self.assertTrue('LeakyRelu' in relu_op.name)
       self.assertAllClose(keras.backend.eval(relu_op), [[-2, 0], [2, 7]])
 
       # max_value < some elements
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index befe82f..6dfbbf3 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -360,7 +360,10 @@
   def on_batch_end(self, batch, logs=None):
     logs = logs or {}
     batch_size = logs.get('size', 0)
-    self.seen += batch_size
+    # With a distribution strategy we can potentially run multiple steps at
+    # the same time, so account for that in the `seen` calculation.
+    num_steps = logs.get('num_steps', 1)
+    self.seen += batch_size * num_steps
 
     for k, v in logs.items():
       if k in self.stateful_metrics:
@@ -448,10 +451,13 @@
   def on_batch_end(self, batch, logs=None):
     logs = logs or {}
     batch_size = logs.get('size', 0)
+    # With a distribution strategy we can potentially run multiple steps at
+    # the same time, so account for that in the `seen` calculation.
+    num_steps = logs.get('num_steps', 1)
     if self.use_steps:
-      self.seen += 1
+      self.seen += num_steps
     else:
-      self.seen += batch_size
+      self.seen += batch_size * num_steps
 
     for k in self.params['metrics']:
       if k in logs:
@@ -1068,7 +1074,7 @@
     logs = logs or {}
     batch_logs = {('batch_' + k): v
                   for k, v in logs.items()
-                  if k not in ['batch', 'size']}
+                  if k not in ['batch', 'size', 'num_steps']}
     self._write_custom_summaries(self._total_batches_seen, batch_logs)
     self._total_batches_seen += 1
 
@@ -1092,7 +1098,7 @@
     # batch number as Tensorboard summaries
     logs = {('epoch_' + k): v
             for k, v in logs.items()
-            if k not in ['batch', 'size']}
+            if k not in ['batch', 'size', 'num_steps']}
     self._write_custom_summaries(epoch, logs)
 
     # pop the histogram summary op after each epoch
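
The `num_steps` accounting added above is easy to get wrong, so here is a tiny sketch with hypothetical batch logs such as a DistributionStrategy run might emit:

```python
# Hypothetical per-batch logs: with a distribution strategy, one callback
# invocation can cover several steps, reported via 'num_steps'.
batch_logs = [
    {'size': 32},                   # Plain run: one step per callback call.
    {'size': 32, 'num_steps': 4},   # Distributed run covering 4 steps.
]

seen_samples = 0
seen_steps = 0
for logs in batch_logs:
  batch_size = logs.get('size', 0)
  num_steps = logs.get('num_steps', 1)   # Defaults to 1, as in the callback.
  seen_samples += batch_size * num_steps
  seen_steps += num_steps

assert seen_samples == 32 + 32 * 4
assert seen_steps == 5
```
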
diff --git a/tensorflow/python/keras/engine/saving_test.py b/tensorflow/python/keras/engine/saving_test.py
index 148dd23..02d99d5 100644
--- a/tensorflow/python/keras/engine/saving_test.py
+++ b/tensorflow/python/keras/engine/saving_test.py
@@ -370,6 +370,13 @@
       y = np.random.random((1, 3, 3))
       model.train_on_batch(x, y)
       new_model.train_on_batch(x, y)
+
+      x = np.random.random((1, 3))
+      y = np.random.random((1, 3, 3))
+      eval_out = model.evaluate(x, y)
+      eval_out2 = new_model.evaluate(x, y)
+      self.assertArrayNear(eval_out, eval_out2, 0.001)
+
       out = model.predict(x)
       out2 = new_model.predict(x)
       self.assertAllClose(out, out2, atol=1e-05)
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index c674946..154c219 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -145,32 +145,34 @@
         if i not in skip_target_weighing_indices
     ]
 
-  def _get_metric_name(self, metric, output_index, weighted=False):
-    """Returns the metric name corresponding to the given metric input.
+  def _cache_output_metric_attributes(self, metrics, weighted_metrics):
+    """Caches metric name and function attributes for every model output."""
+    output_shapes = [
+        None if output is None else output.get_shape().as_list()
+        for output in self.outputs
+    ]
+    self._per_output_metrics = training_utils.collect_per_output_metric_info(
+        metrics, self.output_names, output_shapes, self.loss_functions)
+    self._per_output_weighted_metrics = \
+        training_utils.collect_per_output_metric_info(
+            weighted_metrics, self.output_names, output_shapes,
+            self.loss_functions, self.sample_weights)
+
+  def _add_unique_metric_name(self, metric_name, output_index):
+    """Makes the metric name unique and adds it to the model's metric name list.
+
+      If there are multiple outputs for which the metrics are calculated, the
+      metric names have to be made unique by appending an integer.
 
     Arguments:
-        metric: Metric function name or reference.
-      output_index: Index of the current output.
-        weighted: Boolean indicating if the given metric is weighted.
+      metric_name: Metric name that corresponds to the metric specified by the
+          user. For example: 'acc'.
+      output_index: The index of the model output for which the metric name is
+        being added.
 
     Returns:
-        A metric name.
+      string, the metric name made unique across the model's outputs.
     """
-    metric_name_prefix = 'weighted_' if weighted else ''
-    if metric in ('accuracy', 'acc', 'crossentropy', 'ce'):
-      if metric in ('accuracy', 'acc'):
-        suffix = 'acc'
-      elif metric in ('crossentropy', 'ce'):
-        suffix = 'ce'
-    else:
-      metric_fn = metrics_module.get(metric)
-      # Get metric name as string
-      if hasattr(metric_fn, 'name'):
-        suffix = metric_fn.name
-      else:
-        suffix = metric_fn.__name__
-    metric_name = metric_name_prefix + suffix
-
     if len(self.output_names) > 1:
       metric_name = '%s_%s' % (self.output_names[output_index], metric_name)
     j = 1
@@ -181,75 +183,24 @@
 
     return metric_name
 
-  def _handle_per_output_metrics(self,
-                                 metrics,
-                                 y_true,
-                                 y_pred,
-                                 output_index,
-                                 output_shape,
-                                 loss_fn,
-                                 mask,
-                                 weights=None):
-    """Calls metric functions and sets metric attributes for a single output.
+  def _init_metric_attributes(self):
+    """Initialized model metric attributes."""
+    self.metrics_names = ['loss']
+    self.metrics_tensors = []
+    self.metrics_updates = []
+    self.stateful_metric_names = []
+    self.stateful_metric_functions = []
+
+  def _set_per_output_metric_attributes(self, metrics_dict, output_index):
+    """Sets the metric attributes on the model for the given output.
 
     Arguments:
-      metrics: List of metrics.
-      y_true: Target output.
-      y_pred: Predicted output.
-      output_index: Index of the current output.
-      output_shape: Shape of the current output.
-      loss_fn: Loss function corresponding to the current output.
-      mask: Computed mask value for the current output.
-      weights: Weights to be applied on the current output.
-
-    Returns:
-      A list of metric result tensors.
+      metrics_dict: A dict with metric names as keys and metric fns as values.
+      output_index: The index of the model output for which the metric
+        attributes are added.
     """
-    metric_results = []
-    for metric in metrics:
-      metric_fn = training_utils.get_metric_function(
-          metric, output_shape=output_shape, loss_fn=loss_fn)
-      metric_name = self._get_metric_name(
-          metric, output_index, weighted=weights is not None)
-
-      with K.name_scope(metric_name):
-        # If both outputs and targets are available, call the metric function.
-        if y_true is not None and y_pred is not None:
-          if isinstance(metric_fn, metrics_module.Metric):
-            # Call the stateful metric function.
-            if mask is not None:
-              mask = math_ops.cast(mask, y_pred.dtype)
-              # Update weights with mask.
-              if weights is None:
-                weights = mask
-              else:
-                # Update shape of weights if possible before adding mask.
-                # Update dimensions of weights to match with mask if possible.
-                mask, _, weights = metrics_module.squeeze_or_expand_dimensions(
-                    mask, None, weights)
-                try:
-                  # Broadcast weights if possible.
-                  weights = weights_broadcast_ops.broadcast_weights(
-                      weights, mask)
-                except ValueError:
-                  pass
-                  # TODO(psv): Handle case when mask and weight shapes are not
-                  # compatible.
-                weights *= mask
-
-            metric_result = metric_fn(y_true, y_pred, weights)
-          else:
-            # Call the stateless metric function.
-            weighted_metric_fn = training_utils.weighted_masked_objective(
-                metric_fn)
-            metric_result = weighted_metric_fn(
-                y_true, y_pred, weights=weights, mask=mask)
-
-          if not context.executing_eagerly():
-            # Keep track of metric result tensor.
-            self.metrics_tensors.append(metric_result)
-          metric_results.append(metric_result)
-
+    for metric_name, metric_fn in metrics_dict.items():
+      metric_name = self._add_unique_metric_name(metric_name, output_index)
       # Keep track of metric name.
       self.metrics_names.append(metric_name)
 
@@ -257,9 +208,77 @@
       if isinstance(metric_fn, base_layer.Layer) and metric_fn.stateful:
         self.stateful_metric_names.append(metric_name)
         self.stateful_metric_functions.append(metric_fn)
+
+  def _set_metric_attributes(self, outputs, skip_target_indices=None):
+    """Sets the metric attributes on the model for all the model outputs."""
+    skip_target_indices = skip_target_indices or []
+    for i in range(len(outputs)):
+      if i in skip_target_indices:
+        continue
+      self._set_per_output_metric_attributes(self._per_output_metrics[i], i)
+      self._set_per_output_metric_attributes(
+          self._per_output_weighted_metrics[i], i)
+
+  def _handle_per_output_metrics(self,
+                                 metrics_dict,
+                                 y_true,
+                                 y_pred,
+                                 mask,
+                                 weights=None):
+    """Calls metric functions for a single output.
+
+    Arguments:
+      metrics_dict: A dict with metric names as keys and metric fns as values.
+      y_true: Target output.
+      y_pred: Predicted output.
+      mask: Computed mask value for the current output.
+      weights: Weights to be applied on the current output.
+
+    Returns:
+      A list of metric result tensors.
+    """
+    metric_results = []
+    for metric_name, metric_fn in metrics_dict.items():
+      with K.name_scope(metric_name):
+        if isinstance(metric_fn, metrics_module.Metric):
+          # Call the stateful metric function.
+          if mask is not None:
+            mask = math_ops.cast(mask, y_pred.dtype)
+            # Update weights with mask.
+            if weights is None:
+              weights = mask
+            else:
+              # Update shape of weights if possible before adding mask.
+              # Update dimensions of weights to match with mask if possible.
+              mask, _, weights = metrics_module.squeeze_or_expand_dimensions(
+                  mask, None, weights)
+              try:
+                # Broadcast weights if possible.
+                weights = weights_broadcast_ops.broadcast_weights(weights, mask)
+              except ValueError:
+                pass
+                # TODO(psv): Handle case when mask and weight shapes are not
+                # compatible.
+              weights *= mask
+
+          metric_result = metric_fn(y_true, y_pred, weights)
+        else:
+          # Call the stateless metric function.
+          weighted_metric_fn = training_utils.weighted_masked_objective(
+              metric_fn)
+          metric_result = weighted_metric_fn(
+              y_true, y_pred, weights=weights, mask=mask)
+
         if not context.executing_eagerly():
-          # Keep track of updates created by stateful metrics.
-          self.metrics_updates += metric_fn.updates
+          # Keep track of metric result tensor.
+          self.metrics_tensors.append(metric_result)
+
+      metric_results.append(metric_result)
+      is_stateful = isinstance(metric_fn,
+                               base_layer.Layer) and metric_fn.stateful
+      if is_stateful and not context.executing_eagerly():
+        # Keep track of updates created by stateful metrics.
+        self.metrics_updates += metric_fn.updates
     return metric_results
 
   def _handle_metrics(self,
@@ -268,7 +287,7 @@
                       targets=None,
                       sample_weights=None,
                       masks=None):
-    """Handles calling metric functions and setting model metric attributes.
+    """Handles calling metric functions.
 
     Arguments:
       outputs: List of outputs (predictions).
@@ -288,20 +307,15 @@
           continue
         output = outputs[i] if outputs else None
         target = targets[i] if targets else None
-        output_shape = None if output is None else output.get_shape().as_list()
         output_mask = masks[i] if masks else None
         metric_results.extend(
-            self._handle_per_output_metrics(
-                self.nested_metrics[i], target, output, i, output_shape,
-                self.loss_functions[i], output_mask))
+            self._handle_per_output_metrics(self._per_output_metrics[i], target,
+                                            output, output_mask))
         metric_results.extend(
             self._handle_per_output_metrics(
-                self.nested_weighted_metrics[i],
+                self._per_output_weighted_metrics[i],
                 target,
                 output,
-                i,
-                output_shape,
-                self.loss_functions[i],
                 output_mask,
                 weights=sample_weights[i]))
     return metric_results
@@ -369,27 +383,31 @@
     """
     # Validate that arguments passed by the user to `compile` are supported by
     # DistributionStrategy.
-    if distribute and not isinstance(
-        optimizer, (tf_optimizer_module.Optimizer, optimizers.TFOptimizer)):
-      raise NotImplementedError('Only TF native optimizers are supported with '
-                                'DistributionStrategy.')
-    if distribute and context.executing_eagerly():
-      raise NotImplementedError('DistributionStrategy is not supported in '
-                                'Eager mode.')
-    if distribute and sample_weight_mode:
-      raise NotImplementedError('sample_weight_mode is not supported with '
-                                'DistributionStrategy.')
-    if distribute and weighted_metrics:
-      raise NotImplementedError('weighted_metrics is not supported with '
-                                'DistributionStrategy.')
-    if distribute and target_tensors:
-      raise ValueError('target_tensors is not supported with '
-                       'DistributionStrategy.')
+    if distribute:
+      if not isinstance(
+          optimizer, (tf_optimizer_module.Optimizer, optimizers.TFOptimizer)):
+        raise NotImplementedError(
+            'optimizer must be an instance of '
+            'tf.train.Optimizer, not a %s' % type(optimizer))
+      if context.executing_eagerly():
+        raise NotImplementedError('DistributionStrategy is not supported '
+                                  'when eager execution is enabled.')
+      if sample_weight_mode:
+        raise NotImplementedError('sample_weight_mode is not supported with '
+                                  'DistributionStrategy.')
+      if weighted_metrics:
+        raise NotImplementedError('weighted_metrics is not supported with '
+                                  'DistributionStrategy.')
+      if target_tensors:
+        raise ValueError('target_tensors is not supported with '
+                         'DistributionStrategy.')
 
     loss = loss or {}
     if context.executing_eagerly() and not isinstance(
         optimizer, (tf_optimizer_module.Optimizer, optimizers.TFOptimizer)):
-      raise ValueError('Only TF native optimizers are supported in Eager mode.')
+      raise ValueError(
+          'optimizer must be an instance of tf.train.Optimizer, not '
+          'a %s' % type(optimizer))
 
     self.optimizer = optimizers.get(optimizer)
     # We've disabled automatic dependency tracking for this method, but do want
@@ -408,8 +426,9 @@
 
     # Set DistributionStrategy specific parameters.
     self._distribution_strategy = distribute
+    # Reset the value of grouped_model
+    self._grouped_model = None
     if self._distribution_strategy is not None:
-      self._grouped_model = None
       distributed_training_utils.configure_and_create_session(
           self._distribution_strategy)
     if not self.built:
@@ -431,7 +450,8 @@
       for name in self.output_names:
         if name not in loss:
           logging.warning(
-              'Output "' + name + '" missing from loss dictionary. We assume '
+              'Output "' + name +
+              '" missing from loss dictionary. We assume '
               'this was done on purpose. The fit and evaluate APIs will not be '
               'expecting any data to be passed to "' + name + '".')
         loss_functions.append(losses.get(loss.get(name)))
@@ -493,24 +513,15 @@
     self.loss_weights_list = loss_weights_list
 
     # Initialize model metric attributes.
-    self.metrics_names = ['loss']
-    self.metrics_tensors = []
-    self.metrics_updates = []
-    self.stateful_metric_names = []
-    self.stateful_metric_functions = []
-
-    # Nested metrics is a list of list of metrics.
-    # One list per output of the model.
-    self.nested_metrics = training_utils.collect_metrics(
-        metrics, self.output_names)
-    self.nested_weighted_metrics = training_utils.collect_metrics(
-        weighted_metrics, self.output_names)
+    self._init_metric_attributes()
 
     # Initialization for Eager mode execution.
     if context.executing_eagerly():
       # Prepare sample weights.
       self._set_sample_weight_attributes(sample_weight_mode,
                                          skip_target_weighing_indices)
+      # Save all metric attributes per output of the model.
+      self._cache_output_metric_attributes(metrics, weighted_metrics)
 
       if target_tensors is not None:
         raise ValueError('target_tensors are not currently supported in Eager '
@@ -521,10 +532,10 @@
           self.metrics_names.append(self.output_names[i] + '_loss')
 
       # Set metric attributes on model.
-      self._handle_metrics(
+      self._set_metric_attributes(
           self.outputs,
           skip_target_indices=skip_target_indices,
-          sample_weights=self.sample_weights)
+      )
 
       self.targets = []
       for i in range(len(self.outputs)):
@@ -587,6 +598,8 @@
     # Prepare sample weights.
     self._set_sample_weight_attributes(sample_weight_mode,
                                        skip_target_weighing_indices)
+    # Save all metric attributes per output of the model.
+    self._cache_output_metric_attributes(metrics, weighted_metrics)
 
     # Compute total loss.
     total_loss = None
@@ -621,6 +634,11 @@
       for loss_tensor in self.losses:
         total_loss += loss_tensor
 
+    # Set metric attributes on model.
+    self._set_metric_attributes(
+        self.outputs,
+        skip_target_indices=skip_target_indices,
+    )
     # Invoke metric functions for all the outputs.
     self._handle_metrics(
         self.outputs,
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index 53291c3..26c5ec4 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -20,6 +20,7 @@
 from __future__ import print_function
 import numpy as np
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import errors
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import callbacks as cbks
@@ -292,11 +293,16 @@
   for name, tensor in zip(model.metrics_names[1:], model.metrics_tensors):
     initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)
 
+  if steps_per_epoch is None:
+    raise ValueError('steps_per_epoch should be specified in the fit call.')
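+  # Use a variable rather than a Python int so the number of iterations per
+  # device run can be updated via `load` between session runs (see below).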
+  steps_per_run_var = K.variable(
+      value=min(steps_per_epoch, current_strategy.steps_per_run),
+      dtype='int32',
+      name='steps_per_run_var')
+
   with current_strategy.scope():
-    # TODO(priyag, sourabhbajaj): Adjust steps_per_run appropriately based on
-    # steps_per_epoch and number of epochs.
     ctx = current_strategy.run_steps_on_dataset(
-        step_fn, iterator, iterations=current_strategy.steps_per_run,
+        step_fn, iterator, iterations=steps_per_run_var,
         initial_loop_values=initial_loop_values)
 
   train_op = ctx.run_op
@@ -308,14 +314,6 @@
     distributed_model = current_strategy.unwrap(model._grouped_model)[0]
     distributed_training_utils.set_weights(
         current_strategy, distributed_model, orig_model_weights)
-
-  assert steps_per_epoch is not None
-
-  # TODO(sourabhbajaj): Convert this into a proper validation function
-  if callbacks:
-    raise NotImplementedError(
-        'Callbacks are not supported with TPUStrategy right now.')
-
   callbacks = cbks.configure_callbacks(
       callbacks,
       model,
@@ -326,17 +324,26 @@
       steps_per_epoch=steps_per_epoch,
       verbose=verbose)
   # TODO(priyag, sourabhbajaj): Add callbacks support for per step callback
-  # TODO(priyag, sourabhbajaj): Fix the number of steps run with steps_per_run
   # TODO(priyag, sourabhbajaj): Add validation.
+
+  # Calculate how many steps to run on the device in each outer iteration.
+  steps_to_run = [current_strategy.steps_per_run] * (
+      steps_per_epoch // current_strategy.steps_per_run)
+  if steps_per_epoch % current_strategy.steps_per_run:
+    steps_to_run.append(steps_per_epoch % current_strategy.steps_per_run)
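+  # For example, steps_per_epoch=10 with steps_per_run=4 yields
+  # steps_to_run = [4, 4, 2].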
+
   callbacks.on_train_begin()
   for epoch in range(initial_epoch, epochs):
     callbacks.on_epoch_begin(epoch)
     epoch_logs = {}
-    for step_index in range(0, steps_per_epoch, current_strategy.steps_per_run):
-      # TODO(sourabhbajaj): Replace size with a combination of steps_per_run
-      # and batch_size
-      batch_logs = {'batch': step_index, 'size': 1}
+    step_index = 0
+    prev_step_count = None
+    for step_count in steps_to_run:
+      batch_logs = {'batch': step_index, 'size': 1, 'num_steps': step_count}
       callbacks.on_batch_begin(step_index, batch_logs)
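+      # Reload the variable only when the step count changes (typically just
+      # for the final, smaller chunk of the epoch) to avoid redundant assigns.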
+      if prev_step_count is None or step_count != prev_step_count:
+        steps_per_run_var.load(step_count, K.get_session())
+        prev_step_count = step_count
       try:
         _, outputs = K.get_session().run([train_op, output_tensors])
       except errors.OutOfRangeError:
@@ -349,6 +356,7 @@
 
       batch_logs.update(outputs)
       callbacks.on_batch_end(step_index, batch_logs)
+      step_index = step_index + step_count
       if callbacks.model.stop_training:
         break
 
@@ -742,8 +750,9 @@
   for name, tensor in zip(model.output_names, model.outputs):
     # TODO(priyag): This is a workaround as we do not know the batch dimension
     # of the model's output at this point.
-    tensor.shape.dims = [batch_dimension] + tensor.shape.dims[1:]
-    initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)
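+    # Build a fresh TensorShape so the model output's static shape is not
+    # mutated in place.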
+    shape = tensor_shape.TensorShape(tensor.shape.dims)
+    shape.dims = [batch_dimension] + shape.dims[1:]
+    initial_loop_values[name] = array_ops.zeros(shape, tensor.dtype)
 
   with current_strategy.scope():
     # TODO(priyag, sourabhbajaj): Support steps_per_run if/when we add outfeed.
diff --git a/tensorflow/python/keras/engine/training_eager.py b/tensorflow/python/keras/engine/training_eager.py
index 939a7f2..fb71bf2 100644
--- a/tensorflow/python/keras/engine/training_eager.py
+++ b/tensorflow/python/keras/engine/training_eager.py
@@ -186,7 +186,7 @@
   # make sure either x,y or x,y,sample_weights is provided
   if (not isinstance(inputs.output_shapes, (list, tuple)) or
       len(inputs.output_shapes) not in (2, 3)):
-    raise ValueError('Please provide either inputs and targets'
+    raise ValueError('Please provide either inputs and targets '
                      'or inputs, targets, and sample_weights')
 
   for step_index in range(steps_per_epoch):
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index 3801300..30be413 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -2256,7 +2256,26 @@
         'dense_binary_accuracy', 'dropout_mean_squared_error',
         'dropout_binary_accuracy'
     ]
+    reference_stateful_metric_names = [
+        'dense_binary_accuracy', 'dropout_binary_accuracy'
+    ]
     self.assertEqual(reference_metric_names, model.metrics_names)
+    self.assertEqual(reference_stateful_metric_names,
+                     model.stateful_metric_names)
+
+    # Verify that model metric names are not altered during training.
+    input_a_np = np.random.random((10, 3))
+    input_b_np = np.random.random((10, 3))
+
+    output_d_np = np.random.random((10, 4))
+    output_e_np = np.random.random((10, 4))
+
+    model.fit([input_a_np, input_b_np], [output_d_np, output_e_np],
+              epochs=1,
+              batch_size=5)
+    self.assertEqual(reference_metric_names, model.metrics_names)
+    self.assertEqual(reference_stateful_metric_names,
+                     model.stateful_metric_names)
 
   @tf_test_util.run_in_graph_and_eager_modes
   def test_metrics_correctness(self):
diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py
index 8e9fab8..9c303f4 100644
--- a/tensorflow/python/keras/engine/training_utils.py
+++ b/tensorflow/python/keras/engine/training_utils.py
@@ -18,6 +18,7 @@
 from __future__ import division
 from __future__ import print_function
 
+from collections import OrderedDict
 import copy
 import math
 
@@ -484,29 +485,36 @@
                            'as the output.')
 
 
-def collect_metrics(metrics, output_names):
-  """Maps metric functions to model outputs.
+def collect_per_output_metric_info(metrics,
+                                   output_names,
+                                   output_shapes,
+                                   loss_fns,
+                                   sample_weights=None):
+  """Maps metric names and functions to model outputs.
 
   Arguments:
       metrics: a list or dict of metric functions.
       output_names: a list of the names (strings) of model outputs.
+      output_shapes: a list of the shapes of model outputs.
+      loss_fns: a list of the loss functions corresponding to the model outputs.
+      sample_weights: a list of weights to be applied on the model outputs.
 
   Returns:
-      A list (one entry per model output) of lists of metric functions.
+      A list (one entry per model output) of dicts.
       For instance, if the model has 2 outputs, and for the first output
       we want to compute "binary_accuracy" and "binary_crossentropy",
       and just "binary_accuracy" for the second output,
-      the list would look like:
-          `[[binary_accuracy, binary_crossentropy], [binary_accuracy]]`
+      the list would look like: `[{'acc': binary_accuracy(),
+      'ce': binary_crossentropy()}, {'acc': binary_accuracy()}]`
 
   Raises:
       TypeError: if an incorrect type is passed for the `metrics` argument.
   """
   if not metrics:
-    return [[] for _ in output_names]
+    return [{} for _ in output_names]
   if isinstance(metrics, list):
     # we then apply all metrics to all outputs.
-    return [copy.copy(metrics) for _ in output_names]
+    nested_metrics = [copy.copy(metrics) for _ in output_names]
   elif isinstance(metrics, dict):
     nested_metrics = []
     for name in output_names:
@@ -514,11 +522,24 @@
       if not isinstance(output_metrics, list):
         output_metrics = [output_metrics]
       nested_metrics.append(output_metrics)
-    return nested_metrics
   else:
     raise TypeError('Type of `metrics` argument not understood. '
                     'Expected a list or dictionary, found: ' + str(metrics))
 
+  per_output_metrics = []
+  for i, metrics in enumerate(nested_metrics):
+    metrics_dict = OrderedDict()
+    for metric in metrics:
+      weighted = False if (sample_weights is None) else (
+          sample_weights[i] is not None)
+      metric_name = get_metric_name(metric, weighted)
+      metric_fn = get_metric_function(
+          metric, output_shape=output_shapes[i], loss_fn=loss_fns[i])
+      metrics_dict[metric_name] = metric_fn
+    per_output_metrics.append(metrics_dict)
+
+  return per_output_metrics
+
 
 def batch_shuffle(index_array, batch_size):
   """Shuffles an array in a batch-wise fashion.
@@ -729,6 +750,33 @@
   return tensor_util.is_tensor(ls)
 
 
+def get_metric_name(metric, weighted=False):
+  """Returns the name corresponding to the given metric input.
+
+  Arguments:
+    metric: Metric function name or reference.
+    weighted: Boolean indicating if the given metric is weighted.
+
+  Returns:
+      The metric name.
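+
+  For example, `get_metric_name('acc')` returns `'acc'`, and
+  `get_metric_name('accuracy', weighted=True)` returns `'weighted_acc'`.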
+  """
+  metric_name_prefix = 'weighted_' if weighted else ''
+  if metric in ('accuracy', 'acc', 'crossentropy', 'ce'):
+    if metric in ('accuracy', 'acc'):
+      suffix = 'acc'
+    elif metric in ('crossentropy', 'ce'):
+      suffix = 'ce'
+  else:
+    metric_fn = metrics_module.get(metric)
+    # Get metric name as string
+    if hasattr(metric_fn, 'name'):
+      suffix = metric_fn.name
+    else:
+      suffix = metric_fn.__name__
+  metric_name = metric_name_prefix + suffix
+  return metric_name
+
+
 def get_metric_function(metric, output_shape=None, loss_fn=None):
   """Returns the metric function corresponding to the given metric input.
 
diff --git a/tensorflow/python/keras/layers/advanced_activations.py b/tensorflow/python/keras/layers/advanced_activations.py
index 61ab69c..a2385df 100644
--- a/tensorflow/python/keras/layers/advanced_activations.py
+++ b/tensorflow/python/keras/layers/advanced_activations.py
@@ -18,7 +18,6 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.keras import activations
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import constraints
 from tensorflow.python.keras import initializers
@@ -268,7 +267,7 @@
     self.axis = axis
 
   def call(self, inputs):
-    return activations.softmax(inputs, axis=self.axis)
+    return K.softmax(inputs, axis=self.axis)
 
   def get_config(self):
     config = {'axis': self.axis}
@@ -315,18 +314,19 @@
                        'cannot be negative value: ' + str(negative_slope))
 
     self.support_masking = True
-    self.max_value = K.cast_to_floatx(max_value)
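+    # `max_value` may be None (meaning no upper bound), so only cast it to
+    # floatx when it is set.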
+    if max_value is not None:
+      max_value = K.cast_to_floatx(max_value)
+    self.max_value = max_value
     self.negative_slope = K.cast_to_floatx(negative_slope)
     self.threshold = K.cast_to_floatx(threshold)
 
   def call(self, inputs):
     # alpha is used for leaky relu slope in activations instead of
     # negative_slope.
-    return activations.relu(
-        inputs,
-        alpha=self.negative_slope,
-        max_value=self.max_value,
-        threshold=self.threshold)
+    return K.relu(inputs,
+                  alpha=self.negative_slope,
+                  max_value=self.max_value,
+                  threshold=self.threshold)
 
   def get_config(self):
     config = {
diff --git a/tensorflow/python/keras/layers/advanced_activations_test.py b/tensorflow/python/keras/layers/advanced_activations_test.py
index b020b6e..c41087b 100644
--- a/tensorflow/python/keras/layers/advanced_activations_test.py
+++ b/tensorflow/python/keras/layers/advanced_activations_test.py
@@ -67,6 +67,14 @@
       testing_utils.layer_test(keras.layers.ReLU,
                                kwargs={'max_value': 10},
                                input_shape=(2, 3, 4))
+      x = keras.backend.ones((3, 4))
+      # Test that we use `leaky_relu` when appropriate in graph mode.
+      self.assertTrue(
+          'LeakyRelu' in keras.layers.ReLU(negative_slope=0.2)(x).name)
+      # Test that we use `relu` when appropriate in graph mode.
+      self.assertTrue('Relu' in keras.layers.ReLU()(x).name)
+      # Test that we use `relu6` when appropriate in graph mode.
+      self.assertTrue('Relu6' in keras.layers.ReLU(max_value=6)(x).name)
 
   def test_relu_with_invalid_arg(self):
     with self.assertRaisesRegexp(
diff --git a/tensorflow/python/keras/layers/embeddings.py b/tensorflow/python/keras/layers/embeddings.py
index 629a9ec..c6df5f2 100644
--- a/tensorflow/python/keras/layers/embeddings.py
+++ b/tensorflow/python/keras/layers/embeddings.py
@@ -18,6 +18,8 @@
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import constraints
 from tensorflow.python.keras import initializers
@@ -117,12 +119,27 @@
 
   @tf_utils.shape_type_conversion
   def build(self, input_shape):
-    self.embeddings = self.add_weight(
-        shape=(self.input_dim, self.output_dim),
-        initializer=self.embeddings_initializer,
-        name='embeddings',
-        regularizer=self.embeddings_regularizer,
-        constraint=self.embeddings_constraint)
+    # Note: most sparse optimizers do not have GPU kernels defined. When
+    # building graphs, the placement algorithm is able to place variables on CPU
+    # since it knows all kernels using the variable only exist on CPU.
+    # When eager execution is enabled, the placement decision has to be made
+    # right now. We check for the presence of GPUs here to avoid complicating
+    # the TPU codepaths, which can handle sparse optimizers.
+    if context.executing_eagerly() and context.context().num_gpus():
+      with ops.device('cpu:0'):
+        self.embeddings = self.add_weight(
+            shape=(self.input_dim, self.output_dim),
+            initializer=self.embeddings_initializer,
+            name='embeddings',
+            regularizer=self.embeddings_regularizer,
+            constraint=self.embeddings_constraint)
+    else:
+      self.embeddings = self.add_weight(
+          shape=(self.input_dim, self.output_dim),
+          initializer=self.embeddings_initializer,
+          name='embeddings',
+          regularizer=self.embeddings_regularizer,
+          constraint=self.embeddings_constraint)
     self.built = True
 
   def compute_mask(self, inputs, mask=None):
diff --git a/tensorflow/python/keras/layers/embeddings_test.py b/tensorflow/python/keras/layers/embeddings_test.py
index cab176e..2e42e40 100644
--- a/tensorflow/python/keras/layers/embeddings_test.py
+++ b/tensorflow/python/keras/layers/embeddings_test.py
@@ -21,9 +21,11 @@
 import numpy as np
 
 from tensorflow.python import keras
+from tensorflow.python.eager import backprop
 from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.platform import test
+from tensorflow.python.training import adagrad
 
 
 class EmbeddingTest(test.TestCase):
@@ -78,6 +80,17 @@
       outputs = keras.backend.eval(layer(inputs))
       self.assertAllClose(outputs, [[[1, 1], [2, 2], [1, 1]]])
 
+  @tf_test_util.run_in_graph_and_eager_modes()
+  def test_eager_gpu_cpu(self):
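+    # With eager execution and a GPU available, Embedding.build places the
+    # embedding variable on CPU; the sparse Adagrad update below should still
+    # work in both graph and eager modes.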
+    l = keras.layers.Embedding(output_dim=2, input_dim=2)
+    l.build((None, 2))
+    inputs = keras.backend.constant([[0, 1, 0]], dtype='int32')
+    with backprop.GradientTape() as tape:
+      output = l(inputs)
+    gs = tape.gradient(output, l.weights)
+    opt = adagrad.AdagradOptimizer(0.1)
+    opt.apply_gradients(zip(gs, l.weights))
+    self.assertAllEqual(len(gs), 1)
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index 473d8cd..e64241e 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -199,7 +199,6 @@
     # squeeze last dim of `y_pred` or `y_true` if their rank differs by 1
     y_true, y_pred = confusion_matrix.remove_squeezable_dimensions(
         y_true, y_pred)
-    y_pred.get_shape().assert_is_compatible_with(y_true.get_shape())
 
   if sample_weight is None:
     return y_pred, y_true, None
@@ -342,19 +341,14 @@
       # weak reference. This is to remove reference cycle that is created here.
       # This is not an issue in python versions > 3.
       if context.executing_eagerly():
-        update_state = weakmethod(obj.update_state)
-      else:
-        update_state = function.defun(obj.update_state)
+        obj.update_state = weakmethod(obj.update_state)
       obj.update_state = weakmethod(
-          types.MethodType(update_state_wrapper(update_state), obj))
+          types.MethodType(update_state_wrapper(obj.update_state), obj))
       result = weakmethod(obj.result)
       obj.result = weakmethod(types.MethodType(result_wrapper(result), obj))
     else:
-      # Converting update_state_fn() into a graph function, so that
-      # we can return a single op that performs all of the variable updates.
-      defuned_update_state_fn = function.defun(obj.update_state)
       obj.update_state = types.MethodType(
-          update_state_wrapper(defuned_update_state_fn), obj)
+          update_state_wrapper(obj.update_state), obj)
       obj.result = types.MethodType(result_wrapper(obj.result), obj)
 
     return obj
@@ -475,6 +469,9 @@
     Args:
       values: Per-example value.
       sample_weight: Optional weighting of each example. Defaults to 1.
+
+    Returns:
+      Update op.
     """
     values = math_ops.cast(values, self._dtype)
     if sample_weight is None:
@@ -501,8 +498,9 @@
     values = math_ops.reduce_sum(values)
 
     # Update state variables
-    state_ops.assign_add(self.total, values)
-    state_ops.assign_add(self.count, num_values)
+    update_total_op = state_ops.assign_add(self.total, values)
+    update_count_op = state_ops.assign_add(self.count, num_values)
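+    # Group both assigns so callers get a single update op to run.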
+    return control_flow_ops.group(update_total_op, update_count_op)
 
   def result(self):
     return safe_div(self.total, self.count)
@@ -536,6 +534,9 @@
       sample_weight: Optional weighting of each example. Defaults to 1. Can be
         a `Tensor` whose rank is either 0, or the same rank as `y_true`,
         and must be broadcastable to `y_true`.
+
+    Returns:
+      Update op.
     """
     y_true = math_ops.cast(y_true, self._dtype)
     y_pred = math_ops.cast(y_pred, self._dtype)
@@ -543,7 +544,7 @@
         y_pred, y_true, sample_weight)
 
     matches = self._fn(y_true, y_pred, **self._fn_kwargs)
-    super(MeanMetricWrapper, self).update_state(
+    return super(MeanMetricWrapper, self).update_state(
         matches, sample_weight=sample_weight)
 
   def get_config(self):
@@ -600,6 +601,23 @@
         categorical_accuracy, name, dtype=dtype)
 
 
+class SparseCategoricalAccuracy(MeanMetricWrapper):
+  """Calculates how often predictions matches integer labels.
+
+  This metric creates two local variables, `total` and `count` that are used to
+  compute the frequency with which `y_pred` matches `y_true`. This frequency is
+  ultimately returned as `sparse categorical accuracy`: an idempotent operation
+  that simply divides `total` by `count`.
+
+  If `sample_weight` is `None`, weights default to 1.
+  Use `sample_weight` of 0 to mask values.
+  """
+
+  def __init__(self, name='sparse_categorical_accuracy', dtype=None):
+    super(SparseCategoricalAccuracy, self).__init__(
+        sparse_categorical_accuracy, name, dtype=dtype)
+
+
 @tf_export('keras.metrics.binary_accuracy')
 def binary_accuracy(y_true, y_pred, threshold=0.5):
   threshold = math_ops.cast(threshold, y_pred.dtype)
@@ -615,6 +633,7 @@
       K.floatx())
 
 
+@tf_export('keras.metrics.sparse_categorical_accuracy')
 def sparse_categorical_accuracy(y_true, y_pred):
   y_true = math_ops.reduce_max(y_true, axis=-1)
   y_pred = math_ops.argmax(y_pred, axis=-1)
diff --git a/tensorflow/python/keras/utils/multi_gpu_utils_test.py b/tensorflow/python/keras/utils/multi_gpu_utils_test.py
index c7e9499..d6016ed 100644
--- a/tensorflow/python/keras/utils/multi_gpu_utils_test.py
+++ b/tensorflow/python/keras/utils/multi_gpu_utils_test.py
@@ -48,7 +48,7 @@
     if not check_if_compatible_devices(gpus=gpus):
       return
 
-    with self.test_session():
+    with self.cached_session():
       model = keras.models.Sequential()
       model.add(keras.layers.Dense(hidden_dim,
                                    input_shape=(input_dim,)))
@@ -78,7 +78,7 @@
     if not check_if_compatible_devices(gpus=gpus):
       return
 
-    with self.test_session():
+    with self.cached_session():
       input_a = keras.Input((input_dim_a,))
       input_b = keras.Input((input_dim_b,))
       a = keras.layers.Dense(hidden_dim)(input_a)
@@ -105,7 +105,7 @@
     if not check_if_compatible_devices(gpus=2):
       return
 
-    with self.test_session():
+    with self.cached_session():
       input_shape = (1000, 10)
       model = keras.models.Sequential()
       model.add(keras.layers.Dense(10,
@@ -144,7 +144,7 @@
     if not check_if_compatible_devices(gpus=gpus):
       return
 
-    with self.test_session():
+    with self.cached_session():
       input_shape = (num_samples,) + shape
       x_train = np.random.randint(0, 255, input_shape)
       y_train = np.random.randint(0, num_classes, (input_shape[0],))
diff --git a/tensorflow/python/keras/wrappers/scikit_learn_test.py b/tensorflow/python/keras/wrappers/scikit_learn_test.py
index c322efd..f904290 100644
--- a/tensorflow/python/keras/wrappers/scikit_learn_test.py
+++ b/tensorflow/python/keras/wrappers/scikit_learn_test.py
@@ -102,7 +102,7 @@
 class ScikitLearnAPIWrapperTest(test.TestCase):
 
   def test_classify_build_fn(self):
-    with self.test_session():
+    with self.cached_session():
       clf = keras.wrappers.scikit_learn.KerasClassifier(
           build_fn=build_fn_clf,
           hidden_dim=HIDDEN_DIM,
@@ -118,7 +118,7 @@
       def __call__(self, hidden_dim):
         return build_fn_clf(hidden_dim)
 
-    with self.test_session():
+    with self.cached_session():
       clf = keras.wrappers.scikit_learn.KerasClassifier(
           build_fn=ClassBuildFnClf(),
           hidden_dim=HIDDEN_DIM,
@@ -134,7 +134,7 @@
       def __call__(self, hidden_dim):
         return build_fn_clf(hidden_dim)
 
-    with self.test_session():
+    with self.cached_session():
       clf = InheritClassBuildFnClf(
           build_fn=None,
           hidden_dim=HIDDEN_DIM,
@@ -144,7 +144,7 @@
       assert_classification_works(clf)
 
   def test_regression_build_fn(self):
-    with self.test_session():
+    with self.cached_session():
       reg = keras.wrappers.scikit_learn.KerasRegressor(
           build_fn=build_fn_reg,
           hidden_dim=HIDDEN_DIM,
@@ -160,7 +160,7 @@
       def __call__(self, hidden_dim):
         return build_fn_reg(hidden_dim)
 
-    with self.test_session():
+    with self.cached_session():
       reg = keras.wrappers.scikit_learn.KerasRegressor(
           build_fn=ClassBuildFnReg(),
           hidden_dim=HIDDEN_DIM,
@@ -176,7 +176,7 @@
       def __call__(self, hidden_dim):
         return build_fn_reg(hidden_dim)
 
-    with self.test_session():
+    with self.cached_session():
       reg = InheritClassBuildFnReg(
           build_fn=None,
           hidden_dim=HIDDEN_DIM,
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 6bba99b..17831fa 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -538,6 +538,21 @@
 )
 
 tf_py_test(
+    name = "logging_ops_logging_level_test",
+    size = "small",
+    srcs = ["logging_ops_logging_level_test.py"],
+    additional_deps = [
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:logging_ops",
+    ],
+    tags = [
+        "no_windows",
+    ],
+)
+
+tf_py_test(
     name = "logging_ops_test",
     size = "small",
     srcs = ["logging_ops_test.py"],
@@ -961,6 +976,19 @@
 )
 
 tf_py_test(
+    name = "string_format_op_test",
+    size = "small",
+    srcs = ["string_format_op_test.py"],
+    additional_deps = [
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:string_ops",
+        "//tensorflow/python:math_ops",
+    ],
+)
+
+tf_py_test(
     name = "string_join_op_test",
     size = "small",
     srcs = ["string_join_op_test.py"],
@@ -2799,6 +2827,46 @@
 )
 
 cuda_py_test(
+    name = "cwise_ops_binary_test",
+    size = "medium",
+    srcs = ["cwise_ops_binary_test.py"],
+    additional_deps = [
+        "//third_party/py/numpy",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:gradients",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:math_ops_gen",
+        "//tensorflow/python:nn_grad",
+        "//tensorflow/python:platform",
+        "//tensorflow/python:variables",
+    ],
+    shard_count = 50,
+)
+
+cuda_py_test(
+    name = "cwise_ops_unary_test",
+    size = "medium",
+    srcs = ["cwise_ops_unary_test.py"],
+    additional_deps = [
+        "//third_party/py/numpy",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:gradients",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:math_ops_gen",
+        "//tensorflow/python:nn_grad",
+        "//tensorflow/python:platform",
+        "//tensorflow/python:variables",
+    ],
+    shard_count = 50,
+)
+
+cuda_py_test(
     name = "embedding_ops_test",
     size = "medium",
     srcs = ["embedding_ops_test.py"],
@@ -3164,3 +3232,27 @@
     grpc_enabled = True,
     tags = ["no_gpu"],  # TODO(b/111656070)
 )
+
+# TODO(b/116053459): Replace with cuda_py_test.
+tf_py_test(
+    name = "while_v2_test",
+    size = "medium",
+    srcs = ["while_v2_test.py"],
+    additional_deps = [
+        "@absl_py//absl/testing:parameterized",
+        "//tensorflow/core:protos_all_py",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:control_flow_ops",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:gradients_impl",
+        "//tensorflow/python:list_ops",
+        "//tensorflow/python:tf_optimizer",
+        "//tensorflow/python:while_v2",
+    ],
+    grpc_enabled = True,
+    tags = ["no_gpu"],  # TODO(b/116053459)
+)
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index 573bb86..2fe8583 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -1276,5 +1276,203 @@
         self.assertAllEqual(y.eval(), [0, 1, 2, 3])
 
 
+@test_util.run_all_in_graph_and_eager_modes
+class SortedSearchTest(test_util.TensorFlowTestCase):
+
+  def testUpperBoundFloatHandCoded(self):
+    cdf = np.array([0, .2, .5, .6, .8, 1.], dtype=np.float32)
+    arr = np.array([.04, .99, .53, .58, .31, .01, .79, .8, .21],
+                   dtype=np.float32)
+    result = np.searchsorted(cdf, arr, side="right")
+    tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="right"))
+    self.assertAllEqual(result, tf_result)
+
+  def testUpperBoundFloatRandomNd(self):
+    dim_size = 7
+    for d in range(1, 5):
+      shape = [dim_size] * d
+      cdf = np.cumsum(
+          np.random.uniform(size=shape).astype(np.float32), axis=(d - 1))
+      arr = np.random.uniform(size=shape).astype(np.float32) * dim_size
+
+      tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="right"))
+
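+      # Flatten the leading dimensions so the 1-D np.searchsorted can be
+      # applied row by row to produce the reference result.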
+      cdf = cdf.reshape([-1, dim_size])
+      arr = arr.reshape([-1, dim_size])
+      result = np.zeros(arr.shape, dtype=np.int32)
+      for i in range(dim_size**(d - 1)):
+        result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="right")
+
+      result = result.reshape(shape)
+
+      self.assertAllEqual(result, tf_result)
+
+  def testUpperBoundFloatUneven(self):
+    batch_size = 7
+    size_search_array = 1000
+    size_values = 47
+    cdf = np.cumsum(
+        np.random.uniform(size=[batch_size, size_search_array]).astype(
+            np.float32),
+        axis=1)
+    arr = np.random.uniform(size=[batch_size, size_values]).astype(
+        np.float32) * size_search_array
+
+    tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="right"))
+
+    result = np.zeros(arr.shape, dtype=np.int32)
+    for i in range(batch_size):
+      result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="right")
+
+    self.assertAllEqual(result, tf_result)
+
+  def testLowerBoundFloatHandCoded(self):
+    cdf = np.array([0, .2, .5, .6, .8, 1.], dtype=np.float32)
+    arr = np.array([.04, .99, .53, .58, .31, .01, .79, .8, .21],
+                   dtype=np.float32)
+    result = np.searchsorted(cdf, arr, side="left")
+    tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="left"))
+    self.assertAllEqual(result, tf_result)
+
+  def testLowerBoundFloatRandomNd(self):
+    dim_size = 7
+    for d in range(1, 5):
+      shape = [dim_size] * d
+      cdf = np.cumsum(
+          np.random.uniform(size=shape).astype(np.float32), axis=(d - 1))
+      arr = np.random.uniform(size=shape).astype(np.float32) * dim_size
+
+      tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="left"))
+
+      cdf = cdf.reshape([-1, dim_size])
+      arr = arr.reshape([-1, dim_size])
+      result = np.zeros(arr.shape, dtype=np.int32)
+      for i in range(dim_size**(d - 1)):
+        result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="left")
+
+      result = result.reshape(shape)
+
+      self.assertAllEqual(result, tf_result)
+
+  def testLowerBoundFloatUneven(self):
+    batch_size = 7
+    size_search_array = 1000
+    size_values = 47
+    cdf = np.cumsum(
+        np.random.uniform(size=[batch_size, size_search_array]).astype(
+            np.float32),
+        axis=1)
+    arr = np.random.uniform(size=[batch_size, size_values]).astype(
+        np.float32) * size_search_array
+
+    tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="left"))
+
+    result = np.zeros(arr.shape, dtype=np.int32)
+    for i in range(batch_size):
+      result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="left")
+
+    self.assertAllEqual(result, tf_result)
+
+  def testUpperBoundIntHandCoded(self):
+    cdf = np.array([0, 20, 50, 60, 80, 100], dtype=np.int64)
+    arr = np.array([4, 99, 53, 58, 31, 1, 79, 8, 21], dtype=np.int64)
+    result = np.searchsorted(cdf, arr, side="right")
+    tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="right"))
+    self.assertAllEqual(result, tf_result)
+
+  def testUpperBoundIntRandomNd(self):
+    dim_size = 7
+    for d in range(1, 5):
+      shape = [dim_size] * d
+      cdf = np.cumsum(
+          np.random.randint(low=0, high=10, size=shape).astype(np.int64),
+          axis=(d - 1))
+      arr = np.random.randint(
+          low=0, high=10 * dim_size, size=shape).astype(np.int64)
+
+      tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="right"))
+
+      cdf = cdf.reshape([-1, dim_size])
+      arr = arr.reshape([-1, dim_size])
+      result = np.zeros(arr.shape, dtype=np.int32)
+      for i in range(dim_size**(d - 1)):
+        result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="right")
+
+      result = result.reshape(shape)
+
+      self.assertAllEqual(result, tf_result)
+
+  def testUpperBoundIntUneven(self):
+    batch_size = 7
+    size_search_array = 1000
+    size_values = 47
+    cdf = np.cumsum(
+        np.random.randint(low=0, high=10,
+                          size=[batch_size,
+                                size_search_array]).astype(np.int64),
+        axis=1)
+    arr = np.random.randint(
+        low=0, high=10 * size_search_array, size=[batch_size,
+                                                  size_values]).astype(np.int64)
+
+    tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="right"))
+
+    result = np.zeros(arr.shape, dtype=np.int32)
+    for i in range(batch_size):
+      result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="right")
+
+    self.assertAllEqual(result, tf_result)
+
+  def testLowerBoundIntHandCoded(self):
+    cdf = np.array([0, 20, 50, 60, 80, 100], dtype=np.int64)
+    arr = np.array([4, 99, 53, 58, 31, 1, 79, 8, 21], dtype=np.int64)
+    result = np.searchsorted(cdf, arr, side="left")
+    tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="left"))
+    self.assertAllEqual(result, tf_result)
+
+  def testLowerBoundIntRandomNd(self):
+    dim_size = 7
+    for d in range(1, 5):
+      shape = [dim_size] * d
+      cdf = np.cumsum(
+          np.random.randint(low=0, high=10, size=shape).astype(np.int64),
+          axis=(d - 1))
+      arr = np.random.randint(
+          low=0, high=10 * dim_size, size=shape).astype(np.int64)
+
+      tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="left"))
+
+      cdf = cdf.reshape([-1, dim_size])
+      arr = arr.reshape([-1, dim_size])
+      result = np.zeros(arr.shape, dtype=np.int32)
+      for i in range(dim_size**(d - 1)):
+        result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="left")
+
+      result = result.reshape(shape)
+
+      self.assertAllEqual(result, tf_result)
+
+  def testLowerBoundIntUneven(self):
+    batch_size = 7
+    size_search_array = 1000
+    size_values = 47
+    cdf = np.cumsum(
+        np.random.randint(low=0, high=10,
+                          size=[batch_size,
+                                size_search_array]).astype(np.int64),
+        axis=1)
+    arr = np.random.randint(
+        low=0, high=10 * size_search_array, size=[batch_size,
+                                                  size_values]).astype(np.int64)
+
+    tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="left"))
+
+    result = np.zeros(arr.shape, dtype=np.int32)
+    for i in range(batch_size):
+      result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="left")
+
+    self.assertAllEqual(result, tf_result)
+
+
 if __name__ == "__main__":
   test_lib.main()
diff --git a/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
index dee9610..3b28d44 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
@@ -928,6 +928,163 @@
 class FeatureContribsOpsTest(test_util.TensorFlowTestCase):
   """Tests feature contribs ops for model understanding."""
 
+  def testContribsForOnlyABiasNode(self):
+    """Tests case when, after training, only left with a bias node.
+
+    For example, this could happen if the final ensemble contains one tree that
+    got pruned up to the root.
+    """
+    with self.test_session() as session:
+      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+      text_format.Merge(
+          """
+        trees {
+          nodes {
+            leaf {
+              scalar: 1.72
+            }
+          }
+        }
+        tree_weights: 0.1
+        tree_metadata: {
+          num_layers_grown: 0
+        }
+      """, tree_ensemble_config)
+
+      tree_ensemble = boosted_trees_ops.TreeEnsemble(
+          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+      tree_ensemble_handle = tree_ensemble.resource_handle
+      resources.initialize_resources(resources.shared_resources()).run()
+
+      # All features are unused.
+      feature_0_values = [36, 32]
+      feature_1_values = [13, -29]
+      feature_2_values = [11, 27]
+
+      # Expected logits are computed by traversing the logit path and
+      # subtracting child logits from parent logits.
+      bias = 1.72 * 0.1  # Root node of tree_0.
+      expected_feature_ids = ((), ())
+      expected_logits_paths = ((bias,), (bias,))
+
+      bucketized_features = [
+          feature_0_values, feature_1_values, feature_2_values
+      ]
+
+      debug_op = boosted_trees_ops.example_debug_outputs(
+          tree_ensemble_handle,
+          bucketized_features=bucketized_features,
+          logits_dimension=1)
+
+      serialized_examples_debug_outputs = session.run(debug_op)
+      feature_ids = []
+      logits_paths = []
+      for example in serialized_examples_debug_outputs:
+        example_debug_outputs = boosted_trees_pb2.DebugOutput()
+        example_debug_outputs.ParseFromString(example)
+        feature_ids.append(example_debug_outputs.feature_ids)
+        logits_paths.append(example_debug_outputs.logits_path)
+
+      self.assertAllClose(feature_ids, expected_feature_ids)
+      self.assertAllClose(logits_paths, expected_logits_paths)
+
+  def testContribsMultipleTreeWhenFirstTreeIsABiasNode(self):
+    """Tests case when, after training, first tree contains only a bias node."""
+    with self.test_session() as session:
+      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+      text_format.Merge(
+          """
+        trees {
+          nodes {
+            leaf {
+              scalar: 1.72
+            }
+          }
+        }
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 2
+              threshold: 26
+              left_id: 1
+              right_id: 2
+            }
+          }
+          nodes {
+            bucketized_split {
+              feature_id: 0
+              threshold: 50
+              left_id: 3
+              right_id: 4
+            }
+            metadata {
+              original_leaf: {scalar: 5.5}
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 7.0
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 5.0
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 6.0
+            }
+          }
+        }
+        tree_weights: 1.
+        tree_weights: 0.1
+        tree_metadata: {
+          num_layers_grown: 0
+        }
+        tree_metadata: {
+          num_layers_grown: 1
+        }
+      """, tree_ensemble_config)
+
+      tree_ensemble = boosted_trees_ops.TreeEnsemble(
+          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+      tree_ensemble_handle = tree_ensemble.resource_handle
+      resources.initialize_resources(resources.shared_resources()).run()
+
+      feature_0_values = [36, 32]
+      feature_1_values = [13, -29]  # Unused feature.
+      feature_2_values = [11, 27]
+
+      # Expected logits are computed by traversing the logit path and
+      # subtracting child logits from parent logits.
+      expected_feature_ids = ((2, 0), (2,))
+      # bias = 1.72 * 1.  # Root node of tree_0.
+      # example_0 :  (bias, 0.1 * 5.5 + bias, 0.1 * 5. + bias)
+      # example_1 :  (bias, 0.1 * 7. + bias )
+      expected_logits_paths = ((1.72, 2.27, 2.22), (1.72, 2.42))
+
+      bucketized_features = [
+          feature_0_values, feature_1_values, feature_2_values
+      ]
+
+      debug_op = boosted_trees_ops.example_debug_outputs(
+          tree_ensemble_handle,
+          bucketized_features=bucketized_features,
+          logits_dimension=1)
+
+      serialized_examples_debug_outputs = session.run(debug_op)
+      feature_ids = []
+      logits_paths = []
+      for example in serialized_examples_debug_outputs:
+        example_debug_outputs = boosted_trees_pb2.DebugOutput()
+        example_debug_outputs.ParseFromString(example)
+        feature_ids.append(example_debug_outputs.feature_ids)
+        logits_paths.append(example_debug_outputs.logits_path)
+
+      self.assertAllClose(feature_ids, expected_feature_ids)
+      self.assertAllClose(logits_paths, expected_logits_paths)
+
   def testContribsMultipleTree(self):
     """Tests that the contribs work when we have multiple trees."""
     with self.cached_session() as session:
@@ -1018,11 +1175,14 @@
         tree_weights: 0.2
         tree_weights: 1.0
         tree_metadata: {
-          num_layers_grown: 1}
+          num_layers_grown: 1
+        }
         tree_metadata: {
-          num_layers_grown: 2}
+          num_layers_grown: 2
+        }
         tree_metadata: {
-          num_layers_grown: 1}
+          num_layers_grown: 1
+        }
       """, tree_ensemble_config)
 
       tree_ensemble = boosted_trees_ops.TreeEnsemble(
diff --git a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py
index bd2339f..09c325f 100644
--- a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py
+++ b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py
@@ -90,7 +90,7 @@
       x = constant_op.constant(1, dtype=dtypes.float32)
       v = array_ops.broadcast_to(x, [2, 4, 3])
       out = 2 * v
-      with self.test_session():
+      with self.cached_session():
         err = gradient_checker.compute_gradient_error(x, x.get_shape(),
                                                       out, out.get_shape())
     self.assertLess(err, 1e-4)
@@ -100,7 +100,7 @@
                              dtype=dtypes.float32)
     v = array_ops.broadcast_to(x, [2, 5, 3])
     out = 2 * v
-    with self.test_session():
+    with self.cached_session():
       err = gradient_checker.compute_gradient_error(x, x.get_shape(),
                                                     out, out.get_shape())
     self.assertLess(err, 1e-4)
@@ -110,7 +110,7 @@
                              dtype=dtypes.float32)
     v = array_ops.broadcast_to(x, [5, 2, 3])
     out = 2 * v
-    with self.test_session():
+    with self.cached_session():
       err = gradient_checker.compute_gradient_error(x, x.get_shape(),
                                                     out, out.get_shape())
     self.assertLess(err, 1e-4)
@@ -119,7 +119,7 @@
     x = constant_op.constant([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32)
     v = array_ops.broadcast_to(x, [5, 4, 6])
     out = 2 * v
-    with self.test_session():
+    with self.cached_session():
       err = gradient_checker.compute_gradient_error(x, x.get_shape(),
                                                     out, out.get_shape())
     self.assertLess(err, 1e-4)
diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py
index 27a674e..bd4011d 100644
--- a/tensorflow/python/kernel_tests/check_ops_test.py
+++ b/tensorflow/python/kernel_tests/check_ops_test.py
@@ -785,7 +785,7 @@
     derived = math_ops.divide(placeholder, 3, name="MyDivide")
     derived = check_ops.ensure_shape(derived, (3, 3, 3))
     feed_val = [[1], [2]]
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with self.assertRaisesWithPredicateMatch(
           errors.InvalidArgumentError,
           r"Shape of tensor MyDivide \[2,1\] is not compatible with "
@@ -797,7 +797,7 @@
     derived = placeholder / 3
     derived = check_ops.ensure_shape(derived, (None, None, 3))
     feed_val = [[1], [2]]
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with self.assertRaisesWithPredicateMatch(
           errors.InvalidArgumentError,
           r"Shape of tensor [A-Za-z_]* \[2,1\] is not compatible with "
@@ -809,7 +809,7 @@
     derived = placeholder / 3
     derived = check_ops.ensure_shape(derived, (2, 1))
     feed_val = [[1], [2]]
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(derived, feed_dict={placeholder: feed_val})
 
   def testEnsuresDynamicShape_WithUnknownDims(self):
@@ -817,7 +817,7 @@
     derived = placeholder / 3
     derived = check_ops.ensure_shape(derived, (None, None))
     feed_val = [[1], [2]]
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(derived, feed_dict={placeholder: feed_val})
 
   def testGradient(self):
@@ -826,7 +826,7 @@
     gradient = gradients.gradients(derived, placeholder)
 
     feed_val = [[4.0], [-1.0]]
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       gradient_values, = sess.run(gradient, feed_dict={placeholder: feed_val})
 
     expected = [[1.0], [1.0]]
diff --git a/tensorflow/python/kernel_tests/conditional_accumulator_test.py b/tensorflow/python/kernel_tests/conditional_accumulator_test.py
index 262352a..97ab23f 100644
--- a/tensorflow/python/kernel_tests/conditional_accumulator_test.py
+++ b/tensorflow/python/kernel_tests/conditional_accumulator_test.py
@@ -272,7 +272,7 @@
       self.assertEqual(15.0, val)
 
   def testAccumulatorTakeGradSum(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.ConditionalAccumulator(
           dtypes_lib.float32,
           name="Q",
@@ -349,7 +349,7 @@
       self.assertEqual(elems_ave + 0.0, val)
 
   def testAccumulatorRepeatedTakeGradSum(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.ConditionalAccumulator(
           dtypes_lib.float32,
           name="Q",
diff --git a/tensorflow/python/kernel_tests/cwise_ops_binary_test.py b/tensorflow/python/kernel_tests/cwise_ops_binary_test.py
new file mode 100644
index 0000000..8028f93
--- /dev/null
+++ b/tensorflow/python/kernel_tests/cwise_ops_binary_test.py
@@ -0,0 +1,878 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for binary coefficient-wise operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes as dtypes_lib
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+
+_ADD = lambda x, y: x + y
+_SUB = lambda x, y: x - y
+_MUL = lambda x, y: x * y
+_POW = lambda x, y: x**y
+_TRUEDIV = lambda x, y: x / y
+_FLOORDIV = lambda x, y: x // y
+_MOD = lambda x, y: x % y
+
+
+# TODO(zongheng): it'd be great to factor out this function and various random
+# SparseTensor gen funcs.
+def _sparsify(x, thresh=0.5, index_dtype=np.int64):
+  x[x < thresh] = 0
+
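+  # Gather the coordinates and values of the entries that survived the
+  # threshold into SparseTensor components.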
+  non_zero = np.where(x)
+  x_indices = np.vstack(non_zero).astype(index_dtype).T
+  x_values = x[non_zero]
+  x_shape = x.shape
+
+  return sparse_tensor.SparseTensor(
+      indices=x_indices, values=x_values, dense_shape=x_shape), x_values
+
+
+def _default_tolerance(dtype):
+  """Returns a sensible default tolerance for comparing results of a given type.
+
+  Args:
+    dtype: A datatype.
+  """
+  if dtype == np.float16:
+    return 5e-3
+  elif dtype in (np.float32, np.complex64):
+    return 1e-3
+  elif dtype in (np.float64, np.complex128):
+    return 1e-5
+  else:
+    return None  # Fail fast for unexpected types
+
+
+class BinaryOpTest(test.TestCase):
+
+  def _compareCpu(self, x, y, np_func, tf_func, also_compare_variables=False):
+    np_ans = np_func(x, y)
+    with self.test_session(use_gpu=False):
+      inx = ops.convert_to_tensor(x)
+      iny = ops.convert_to_tensor(y)
+      out = tf_func(inx, iny)
+      tf_cpu = out.eval()
+      # Test that the op takes precedence over numpy operators.
+      np_left = tf_func(x, iny).eval()
+      np_right = tf_func(inx, y).eval()
+
+      if also_compare_variables:
+        var_x = variables.Variable(x)
+        var_y = variables.Variable(y)
+        variables.global_variables_initializer().run()
+        print(type(x), type(y), type(var_x), type(var_y))
+        print(type(tf_func(x, var_y)), type(tf_func(var_x, y)))
+        np_var_left = tf_func(x, var_y).eval()
+        np_var_right = tf_func(var_x, y).eval()
+
+    if np_ans.dtype != np.object:
+      self.assertAllClose(np_ans, tf_cpu)
+      self.assertAllClose(np_ans, np_left)
+      self.assertAllClose(np_ans, np_right)
+      if also_compare_variables:
+        self.assertAllClose(np_ans, np_var_left)
+        self.assertAllClose(np_ans, np_var_right)
+    self.assertShapeEqual(np_ans, out)
+
+  _GRAD_TOL = {
+      dtypes_lib.float16: 1e-3,
+      dtypes_lib.float32: 1e-3,
+      dtypes_lib.complex64: 1e-2,
+      dtypes_lib.float64: 1e-5,
+      dtypes_lib.complex128: 1e-4
+  }
+
+  def _compareGradientX(self,
+                        x,
+                        y,
+                        np_func,
+                        tf_func,
+                        numeric_gradient_type=None):
+    z = np_func(x, y)
+    zs = list(z.shape)
+    with self.cached_session():
+      inx = ops.convert_to_tensor(x)
+      iny = ops.convert_to_tensor(y)
+      if x.dtype in (np.float32, np.float64):
+        out = 1.1 * tf_func(inx, iny)
+      else:
+        out = tf_func(inx, iny)
+      xs = list(x.shape)
+      jacob_t, jacob_n = gradient_checker.compute_gradient(
+          inx, xs, out, zs, x_init_value=x)
+      if numeric_gradient_type is not None:
+        xf = x.astype(numeric_gradient_type)
+        yf = y.astype(numeric_gradient_type)
+        inxf = ops.convert_to_tensor(xf)
+        inyf = ops.convert_to_tensor(yf)
+        outf = tf_func(inxf, inyf)
+        _, jacob_n = gradient_checker.compute_gradient(
+            inxf, xs, outf, zs, x_init_value=xf, delta=1e-3)
+        jacob_n = jacob_n.astype(x.dtype)
+      tol = self._GRAD_TOL[dtypes_lib.as_dtype(x.dtype)]
+      self.assertAllClose(jacob_t, jacob_n, rtol=tol, atol=tol)
+
+  def _compareGradientY(self,
+                        x,
+                        y,
+                        np_func,
+                        tf_func,
+                        numeric_gradient_type=None):
+    z = np_func(x, y)
+    zs = list(z.shape)
+    with self.cached_session():
+      inx = ops.convert_to_tensor(x)
+      iny = ops.convert_to_tensor(y)
+      if x.dtype in (np.float32, np.float64):
+        out = 1.1 * tf_func(inx, iny)
+      else:
+        out = tf_func(inx, iny)
+      ys = list(np.shape(y))
+      jacob_t, jacob_n = gradient_checker.compute_gradient(
+          iny, ys, out, zs, x_init_value=y)
+      if numeric_gradient_type is not None:
+        xf = x.astype(numeric_gradient_type)
+        yf = y.astype(numeric_gradient_type)
+        inxf = ops.convert_to_tensor(xf)
+        inyf = ops.convert_to_tensor(yf)
+        outf = tf_func(inxf, inyf)
+        _, jacob_n = gradient_checker.compute_gradient(
+            inyf, ys, outf, zs, x_init_value=yf)
+        jacob_n = jacob_n.astype(x.dtype)
+    tol = self._GRAD_TOL[dtypes_lib.as_dtype(x.dtype)]
+    self.assertAllClose(jacob_t, jacob_n, rtol=tol, atol=tol)
+
+  def _compareGpu(self, x, y, np_func, tf_func):
+    np_ans = np_func(x, y)
+    with self.test_session(force_gpu=test_util.is_gpu_available()):
+      inx = ops.convert_to_tensor(x)
+      iny = ops.convert_to_tensor(y)
+      out = tf_func(inx, iny)
+      tf_gpu = out.eval()
+    self.assertAllClose(np_ans, tf_gpu)
+    self.assertShapeEqual(np_ans, out)
+    # TODO(zhifengc/ke): make gradient checker work on GPU.
+
+  def _compareBoth(self, x, y, np_func, tf_func, also_compare_variables=False):
+    self._compareCpu(x, y, np_func, tf_func, also_compare_variables)
+    if x.dtype in (np.float16, np.float32, np.float64, np.complex64,
+                   np.complex128):
+      if tf_func not in (_FLOORDIV, math_ops.floordiv, math_ops.zeta,
+                         math_ops.polygamma):
+        self._compareGradientX(x, y, np_func, tf_func)
+        self._compareGradientY(x, y, np_func, tf_func)
+      if tf_func in (math_ops.zeta, math_ops.polygamma):
+        # These ops only support gradients with respect to the second parameter.
+        self._compareGradientY(x, y, np_func, tf_func)
+      self._compareGpu(x, y, np_func, tf_func)
+
+  def testFloatBasic(self):
+    x = np.linspace(-5, 20, 15).reshape(1, 3, 5).astype(np.float32)
+    y = np.linspace(20, -5, 15).reshape(1, 3, 5).astype(np.float32)
+    self._compareBoth(x, y, np.add, math_ops.add, also_compare_variables=True)
+    self._compareBoth(x, y, np.subtract, math_ops.subtract)
+    self._compareBoth(x, y, np.multiply, math_ops.multiply)
+    self._compareBoth(x, y + 0.1, np.true_divide, math_ops.truediv)
+    self._compareBoth(x, y + 0.1, np.floor_divide, math_ops.floordiv)
+    self._compareBoth(x, y, np.add, _ADD)
+    self._compareBoth(x, y, np.subtract, _SUB)
+    self._compareBoth(x, y, np.multiply, _MUL)
+    self._compareBoth(x, y + 0.1, np.true_divide, _TRUEDIV)
+    self._compareBoth(x, y + 0.1, np.floor_divide, _FLOORDIV)
+    self._compareBoth(x, y, np.arctan2, math_ops.atan2)
+    x1 = np.random.randn(5, 6).astype(np.float32)
+    x2 = np.random.randn(5, 6).astype(np.float32)
+    # Remove tiny values--atan2 gradients are flaky near the origin.
+    x1[np.abs(x1) < 0.05] = 0.05 * np.sign(x1[np.abs(x1) < 0.05])
+    x2[np.abs(x2) < 0.05] = 0.05 * np.sign(x2[np.abs(x2) < 0.05])
+    self._compareBoth(x1, x2, np.arctan2, math_ops.atan2)
+    try:
+      from scipy import special  # pylint: disable=g-import-not-at-top
+      a_pos_small = np.linspace(0.1, 2, 15).reshape(1, 3, 5).astype(np.float32)
+      x_pos_small = np.linspace(0.1, 10, 15).reshape(1, 3, 5).astype(np.float32)
+      self._compareBoth(a_pos_small, x_pos_small, special.gammainc,
+                        math_ops.igamma)
+      self._compareBoth(a_pos_small, x_pos_small, special.gammaincc,
+                        math_ops.igammac)
+      # Need x > 1
+      self._compareBoth(x_pos_small + 1, a_pos_small, special.zeta,
+                        math_ops.zeta)
+      n_small = np.arange(0, 15).reshape(1, 3, 5).astype(np.float32)
+      self._compareBoth(n_small, x_pos_small, special.polygamma,
+                        math_ops.polygamma)
+    except ImportError as e:
+      tf_logging.warn("Cannot test special functions: %s" % str(e))
+
+  def testFloatDifferentShapes(self):
+    x = np.array([1, 2, 3, 4]).reshape(2, 2).astype(np.float32)
+    y = np.array([1, 2]).reshape(2, 1).astype(np.float32)
+    with self.cached_session() as sess:
+      inx = ops.convert_to_tensor(x)
+      iny = ops.convert_to_tensor(y)
+      s = math_ops.reduce_sum(inx * iny)
+      gx, gy = sess.run(gradients_impl.gradients(s, [inx, iny]))
+    # gx is simply the broadcasted y
+    self.assertAllEqual(gx,
+                        np.array([1, 1, 2, 2]).reshape(2, 2).astype(np.float32))
+    # gy is x summed along its last axis (row sums), matching y's shape
+    self.assertAllEqual(gy, np.array([3, 7]).reshape(2, 1).astype(np.float32))
+
+  def testFloatVariableOverload(self):
+    x = np.array([1, 2, 3, 4]).reshape(2, 2).astype(np.int32)
+    y = np.array([1, 2]).reshape(2, 1).astype(np.int32)
+    var_x = variables.Variable(x)
+    var_y = variables.Variable(y)
+    with self.cached_session() as sess:
+      sess.run([var_x.initializer, var_y.initializer])
+      left_result = (var_x * y).eval()
+      right_result = (x * var_y).eval()
+    np_result = x * y
+    self.assertAllEqual(np_result, left_result)
+    self.assertAllEqual(np_result, right_result)
+
+  def testDoubleBasic(self):
+    x = np.linspace(-5, 20, 15).reshape(1, 3, 5).astype(np.float64)
+    y = np.linspace(20, -5, 15).reshape(1, 3, 5).astype(np.float64)
+    self._compareBoth(x, y, np.add, math_ops.add)
+    self._compareBoth(x, y, np.subtract, math_ops.subtract)
+    self._compareBoth(x, y, np.multiply, math_ops.multiply)
+    self._compareBoth(x, y + 0.1, np.true_divide, math_ops.truediv)
+    self._compareBoth(x, y + 0.1, np.floor_divide, math_ops.floordiv)
+    self._compareBoth(x, y, np.add, _ADD)
+    self._compareBoth(x, y, np.subtract, _SUB)
+    self._compareBoth(x, y, np.multiply, _MUL)
+    self._compareBoth(x, y + 0.1, np.true_divide, _TRUEDIV)
+    self._compareBoth(x, y + 0.1, np.floor_divide, _FLOORDIV)
+    self._compareBoth(x, y, np.arctan2, math_ops.atan2)
+    x1 = np.random.randn(7, 4).astype(np.float64)
+    x2 = np.random.randn(7, 4).astype(np.float64)
+    # Remove tiny values--atan2 gradients are flaky near the origin.
+    x1[np.abs(x1) < 0.5] = 0.5 * np.sign(x1[np.abs(x1) < 0.5])
+    x2[np.abs(x2) < 0.5] = 0.5 * np.sign(x2[np.abs(x2) < 0.5])
+    self._compareBoth(x1, x2, np.arctan2, math_ops.atan2)
+    try:
+      from scipy import special  # pylint: disable=g-import-not-at-top
+      a_pos_small = np.linspace(0.1, 2, 15).reshape(1, 3, 5).astype(np.float32)
+      x_pos_small = np.linspace(0.1, 10, 15).reshape(1, 3, 5).astype(np.float32)
+      self._compareBoth(a_pos_small, x_pos_small, special.gammainc,
+                        math_ops.igamma)
+      self._compareBoth(a_pos_small, x_pos_small, special.gammaincc,
+                        math_ops.igammac)
+    except ImportError as e:
+      tf_logging.warn("Cannot test special functions: %s" % str(e))
+
+  def testUint8Basic(self):
+    x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.uint8)
+    y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.uint8)
+    self._compareBoth(x, y, np.add, math_ops.add)
+
+  def testInt8Basic(self):
+    x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.int8)
+    y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int8)
+    self._compareBoth(x, y, np.multiply, math_ops.multiply)
+    self._compareBoth(x, y, np.multiply, _MUL)
+
+  def testInt16Basic(self):
+    x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.int16)
+    y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int16)
+    self._compareBoth(x, y, np.multiply, math_ops.multiply)
+    self._compareBoth(x, y, np.multiply, _MUL)
+
+  def testUint16Basic(self):
+    x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.uint16)
+    y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.uint16)
+    self._compareBoth(x, y, np.multiply, math_ops.multiply)
+    self._compareBoth(x, y, np.multiply, _MUL)
+    self._compareBoth(x, y, np.true_divide, math_ops.truediv)
+    self._compareBoth(x, y, np.floor_divide, math_ops.floordiv)
+    self._compareBoth(x, y, np.true_divide, _TRUEDIV)
+    self._compareBoth(x, y, np.floor_divide, _FLOORDIV)
+
+  def testInt32Basic(self):
+    x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.int32)
+    y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int32)
+    self._compareBoth(x, y, np.add, math_ops.add)
+    self._compareBoth(x, y, np.subtract, math_ops.subtract)
+    self._compareBoth(x, y, np.multiply, math_ops.multiply)
+    self._compareBoth(x, y, np.true_divide, math_ops.truediv)
+    self._compareBoth(x, y, np.floor_divide, math_ops.floordiv)
+    self._compareBoth(x, y, np.mod, math_ops.mod)
+    self._compareBoth(x, y, np.add, _ADD)
+    self._compareBoth(x, y, np.subtract, _SUB)
+    self._compareBoth(x, y, np.multiply, _MUL)
+    self._compareBoth(x, y, np.true_divide, _TRUEDIV)
+    self._compareBoth(x, y, np.floor_divide, _FLOORDIV)
+    self._compareBoth(x, y, np.mod, _MOD)
+    # _compareBoth only runs the GPU comparison for floating-point types, so
+    # exercise _MOD for int32 on the GPU explicitly via _compareGpu.
+    self._compareGpu(x, y, np.mod, _MOD)
+
+  def testInt64Basic(self):
+    x = np.arange(1 << 40, 13 << 40, 2 << 40).reshape(1, 3, 2).astype(np.int64)
+    y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int64)
+    self._compareBoth(x, y, np.subtract, math_ops.subtract)
+    self._compareBoth(x, y, np.multiply, math_ops.multiply)
+    self._compareBoth(x, y, np.true_divide, math_ops.truediv)
+    self._compareBoth(x, y, np.floor_divide, math_ops.floordiv)
+    self._compareBoth(x, y, np.mod, math_ops.mod)
+    self._compareBoth(x, y, np.subtract, _SUB)
+    self._compareBoth(x, y, np.multiply, _MUL)
+    self._compareBoth(x, y, np.true_divide, _TRUEDIV)
+    self._compareBoth(x, y, np.floor_divide, _FLOORDIV)
+    self._compareBoth(x, y, np.mod, _MOD)
+
+  def testComplex64Basic(self):
+    x = np.complex(1, 1) * np.linspace(-10, 10, 6).reshape(1, 3, 2).astype(
+        np.complex64)
+    y = np.complex(1, 1) * np.linspace(20, -20, 6).reshape(1, 3, 2).astype(
+        np.complex64)
+    self._compareBoth(x, y, np.add, math_ops.add)
+    self._compareBoth(x, y, np.subtract, math_ops.subtract)
+    self._compareBoth(x, y, np.multiply, math_ops.multiply)
+    self._compareBoth(x, y + 0.1, np.true_divide, math_ops.truediv)
+    self._compareBoth(x, y, np.add, _ADD)
+    self._compareBoth(x, y, np.subtract, _SUB)
+    self._compareBoth(x, y, np.multiply, _MUL)
+    self._compareBoth(x, y + 0.1, np.true_divide, _TRUEDIV)
+
+  def testComplex128Basic(self):
+    x = np.complex(1, 1) * np.linspace(-10, 10, 6).reshape(1, 3, 2).astype(
+        np.complex128)
+    y = np.complex(1, 1) * np.linspace(20, -20, 6).reshape(1, 3, 2).astype(
+        np.complex128)
+    self._compareBoth(x, y, np.add, math_ops.add)
+    self._compareBoth(x, y, np.subtract, math_ops.subtract)
+    self._compareBoth(x, y, np.multiply, math_ops.multiply)
+    self._compareBoth(x, y + 0.1, np.true_divide, math_ops.truediv)
+    self._compareBoth(x, y, np.add, _ADD)
+    self._compareBoth(x, y, np.subtract, _SUB)
+    self._compareBoth(x, y, np.multiply, _MUL)
+    self._compareBoth(x, y + 0.1, np.true_divide, _TRUEDIV)
+
+  def testStringComparison(self):
+    x = np.array([["abc", "bh"], ["c", ""]])
+    y = np.array([["abc", "bh"], ["def", "hi"]])
+    with self.test_session(use_gpu=False) as sess:
+      cmp_eq = math_ops.equal(x, y)
+      cmp_not_eq = math_ops.not_equal(x, y)
+      values = sess.run([cmp_eq, cmp_not_eq])
+      self.assertAllEqual([[True, True], [False, False]], values[0])
+      self.assertAllEqual([[False, False], [True, True]], values[1])
+
+  def testString(self):
+    x = np.array([["x_0_0", "x_0_1", "x_0_2"], ["x_1_0", "x_1_1", "x_1_2"],
+                  ["x_2_0", "x_2_1", "x_2_2"]],
+                 dtype=np.object)
+    y = np.array([["y_0_0", "y_0_1", "y_0_2"], ["y_1_0", "y_1_1", "y_1_2"],
+                  ["y_2_0", "y_2_1", "y_2_2"]],
+                 dtype=np.object)
+    z = np.array([["z_0", "z_1", "z_2"]], dtype=np.object)
+    w = np.array("w", dtype=np.object)
+    self._compareCpu(x, y, _ADD, _ADD)
+    self._compareCpu(x, z, _ADD, _ADD)
+    self._compareCpu(x, w, _ADD, _ADD)
+    self._compareCpu(z, w, _ADD, _ADD)
+
+  def _compareBCast(self, xs, ys, dtype, np_func, tf_func):
+    if dtype in (np.complex64, np.complex128):
+      x = (1 + np.linspace(0, 2 + 3j, np.prod(xs))).astype(dtype).reshape(xs)
+      y = (1 + np.linspace(0, 2 - 2j, np.prod(ys))).astype(dtype).reshape(ys)
+    else:
+      x = (1 + np.linspace(0, 5, np.prod(xs))).astype(dtype).reshape(xs)
+      y = (1 + np.linspace(0, 5, np.prod(ys))).astype(dtype).reshape(ys)
+    self._compareCpu(x, y, np_func, tf_func)
+    if x.dtype in (np.float16, np.float32, np.float64):
+      # TODO(aselle): Make the test work for dtypes:
+      #     (np.complex64, np.complex128).
+      if tf_func not in (_FLOORDIV, math_ops.floordiv):
+        if x.dtype == np.float16:
+          # Compare fp16 theoretical gradients to fp32 numerical gradients,
+          # since fp16 numerical gradients are too imprecise unless great
+          # care is taken with choosing the inputs and the delta. This is
+          # a weaker check (in particular, it does not test the op itself,
+          # only its gradient), but it's much better than nothing.
+          self._compareGradientX(x, y, np_func, tf_func, np.float)
+          self._compareGradientY(x, y, np_func, tf_func, np.float)
+        else:
+          self._compareGradientX(x, y, np_func, tf_func)
+          self._compareGradientY(x, y, np_func, tf_func)
+      self._compareGpu(x, y, np_func, tf_func)
+
+  # TODO(josh11b,vrv): Refactor this to use parameterized tests.
+  def _testBCastByFunc(self, funcs, xs, ys):
+    dtypes = [
+        np.float16,
+        np.float32,
+        np.float64,
+        np.int32,
+        np.int64,
+        np.complex64,
+        np.complex128,
+    ]
+    for dtype in dtypes:
+      for (np_func, tf_func) in funcs:
+        if (dtype in (np.complex64, np.complex128) and
+            tf_func in (_FLOORDIV, math_ops.floordiv)):
+          continue  # floordiv makes no sense for complex numbers
+        self._compareBCast(xs, ys, dtype, np_func, tf_func)
+        self._compareBCast(ys, xs, dtype, np_func, tf_func)
+
+  def _testBCastA(self, xs, ys):
+    funcs = [
+        (np.add, math_ops.add),
+        (np.add, _ADD),
+    ]
+    self._testBCastByFunc(funcs, xs, ys)
+
+  def _testBCastB(self, xs, ys):
+    funcs = [
+        (np.subtract, math_ops.subtract),
+        (np.subtract, _SUB),
+        (np.power, math_ops.pow),
+    ]
+    self._testBCastByFunc(funcs, xs, ys)
+
+  def _testBCastC(self, xs, ys):
+    funcs = [
+        (np.multiply, math_ops.multiply),
+        (np.multiply, _MUL),
+    ]
+    self._testBCastByFunc(funcs, xs, ys)
+
+  def _testBCastD(self, xs, ys):
+    funcs = [
+        (np.true_divide, math_ops.truediv),
+        (np.floor_divide, math_ops.floordiv),
+        (np.true_divide, _TRUEDIV),
+        (np.floor_divide, _FLOORDIV),
+    ]
+    self._testBCastByFunc(funcs, xs, ys)
+
+  def testBCast_0A(self):
+    self._testBCastA([1, 3, 2], [1])
+
+  def testBCast_0B(self):
+    self._testBCastB([1, 3, 2], [1])
+
+  def testBCast_0C(self):
+    self._testBCastC([1, 3, 2], [1])
+
+  def testBCast_0D(self):
+    self._testBCastD([1, 3, 2], [1])
+
+  def testBCast_1A(self):
+    self._testBCastA([1, 3, 2], [2])
+
+  def testBCast_1B(self):
+    self._testBCastB([1, 3, 2], [2])
+
+  def testBCast_1C(self):
+    self._testBCastC([1, 3, 2], [2])
+
+  def testBCast_1D(self):
+    self._testBCastD([1, 3, 2], [2])
+
+  def testBCast_2A(self):
+    self._testBCastA([1, 3, 2], [3, 2])
+
+  def testBCast_2B(self):
+    self._testBCastB([1, 3, 2], [3, 2])
+
+  def testBCast_2C(self):
+    self._testBCastC([1, 3, 2], [3, 2])
+
+  def testBCast_2D(self):
+    self._testBCastD([1, 3, 2], [3, 2])
+
+  def testBCast_3A(self):
+    self._testBCastA([1, 3, 2], [3, 1])
+
+  def testBCast_3B(self):
+    self._testBCastB([1, 3, 2], [3, 1])
+
+  def testBCast_3C(self):
+    self._testBCastC([1, 3, 2], [3, 1])
+
+  def testBCast_3D(self):
+    self._testBCastD([1, 3, 2], [3, 1])
+
+  def testBCast_4A(self):
+    self._testBCastA([1, 3, 2], [1, 3, 2])
+
+  def testBCast_4B(self):
+    self._testBCastB([1, 3, 2], [1, 3, 2])
+
+  def testBCast_4C(self):
+    self._testBCastC([1, 3, 2], [1, 3, 2])
+
+  def testBCast_4D(self):
+    self._testBCastD([1, 3, 2], [1, 3, 2])
+
+  def testBCast_5A(self):
+    self._testBCastA([1, 3, 2], [2, 3, 1])
+
+  def testBCast_5B(self):
+    self._testBCastB([1, 3, 2], [2, 3, 1])
+
+  def testBCast_5C(self):
+    self._testBCastC([1, 3, 2], [2, 3, 1])
+
+  def testBCast_5D(self):
+    self._testBCastD([1, 3, 2], [2, 3, 1])
+
+  def testBCast_6A(self):
+    self._testBCastA([1, 3, 2], [2, 1, 1])
+
+  def testBCast_6B(self):
+    self._testBCastB([1, 3, 2], [2, 1, 1])
+
+  def testBCast_6C(self):
+    self._testBCastC([1, 3, 2], [2, 1, 1])
+
+  def testBCast_6D(self):
+    self._testBCastD([1, 3, 2], [2, 1, 1])
+
+  def testBCast_7A(self):
+    self._testBCastA([1, 3, 2], [1, 3, 1])
+
+  def testBCast_7B(self):
+    self._testBCastB([1, 3, 2], [1, 3, 1])
+
+  def testBCast_7C(self):
+    self._testBCastC([1, 3, 2], [1, 3, 1])
+
+  def testBCast_7D(self):
+    self._testBCastD([1, 3, 2], [1, 3, 1])
+
+  def testBCast_8A(self):
+    self._testBCastA([2, 1, 5], [2, 3, 1])
+
+  def testBCast_8B(self):
+    self._testBCastB([2, 1, 5], [2, 3, 1])
+
+  def testBCast_8C(self):
+    self._testBCastC([2, 1, 5], [2, 3, 1])
+
+  def testBCast_8D(self):
+    self._testBCastD([2, 1, 5], [2, 3, 1])
+
+  def testBCast_9A(self):
+    self._testBCastA([2, 0, 5], [2, 0, 1])
+
+  def testBCast_9B(self):
+    self._testBCastB([2, 0, 5], [2, 0, 1])
+
+  def testBCast_9C(self):
+    self._testBCastC([2, 0, 5], [2, 0, 1])
+
+  def testBCast_9D(self):
+    self._testBCastD([2, 0, 5], [2, 0, 1])
+
+  def testBCast_10A(self):
+    self._testBCastA([2, 3, 0], [2, 3, 1])
+
+  def testBCast_10B(self):
+    self._testBCastB([2, 3, 0], [2, 3, 1])
+
+  def testBCast_10C(self):
+    self._testBCastC([2, 3, 0], [2, 3, 1])
+
+  def testBCast_10D(self):
+    self._testBCastD([2, 3, 0], [2, 3, 1])
+
+  def testBCast_11A(self):
+    self._testBCastA([1, 3, 2], [1, 3, 2])
+
+  def testBCast_11B(self):
+    self._testBCastB([1, 3, 2], [1, 3, 2])
+
+  def testBCast_11C(self):
+    self._testBCastC([1, 3, 2], [1, 3, 2])
+
+  def testBCast_11D(self):
+    self._testBCastD([1, 3, 2], [1, 3, 2])
+
+  def testBCast_12A(self):
+    self._testBCastA([1, 1, 1, 1, 3, 2], [1, 3, 2])
+
+  def testBCast_12B(self):
+    self._testBCastB([1, 1, 1, 1, 3, 2], [1, 3, 2])
+
+  def testBCast_12C(self):
+    self._testBCastC([1, 1, 1, 1, 3, 2], [1, 3, 2])
+
+  def testBCast_12D(self):
+    self._testBCastD([1, 1, 1, 1, 3, 2], [1, 3, 2])
+
+  def testBCast_13A(self):
+    self._testBCastA([1, 3, 2, 1, 1], [1])
+
+  def testBCast_13B(self):
+    self._testBCastB([1, 3, 2, 1, 1], [1])
+
+  def testBCast_13C(self):
+    self._testBCastC([1, 3, 2, 1, 1], [1])
+
+  def testBCast_13D(self):
+    self._testBCastD([1, 3, 2, 1, 1], [1])
+
+  def testBCast_14A(self):
+    self._testBCastA([2, 3, 1, 1, 5], [1])
+
+  def testBCast_14B(self):
+    self._testBCastB([2, 3, 1, 1, 5], [1])
+
+  def testBCast_14C(self):
+    self._testBCastC([2, 3, 1, 1, 5], [1])
+
+  def testBCast_14D(self):
+    self._testBCastD([2, 3, 1, 1, 5], [1])
+
+  def testBCast_15A(self):
+    self._testBCastA([10, 3, 1, 2], [3, 1, 2])
+
+  def testBCast_15B(self):
+    self._testBCastB([10, 3, 1, 2], [3, 1, 2])
+
+  def testBCast_15C(self):
+    self._testBCastC([10, 3, 1, 2], [3, 1, 2])
+
+  def testBCast_15D(self):
+    self._testBCastD([10, 3, 1, 2], [3, 1, 2])
+
+  def testMismatchedDimensions(self):
+    for func in [
+        math_ops.add, math_ops.subtract, math_ops.multiply, math_ops.div, _ADD,
+        _SUB, _MUL, _TRUEDIV, _FLOORDIV
+    ]:
+      with self.assertRaisesWithPredicateMatch(
+          ValueError, lambda e: "Dimensions must" in str(e)):
+        func(
+            ops.convert_to_tensor([10.0, 20.0, 30.0]),
+            ops.convert_to_tensor([[40.0, 50.0], [60.0, 70.0]]))
+
+  def testZeroPowGrad(self):
+    with self.cached_session():
+      for dtype in (np.float16, np.float32, np.float64, np.complex64,
+                    np.complex128):
+        x = constant_op.constant(0.0, dtype=dtype)
+        y = constant_op.constant(2.0, dtype=dtype)
+        z = math_ops.pow(x, y)
+        error = gradient_checker.compute_gradient_error(y, [], z, [])
+        self.assertEqual(error, 0)
+
+  def testComplexPowGrad(self):
+    with self.cached_session():
+      for dtype in np.complex64, np.complex128:
+        for base in 2.0, -2.0:
+          x = constant_op.constant(base, dtype=dtype)
+          y = constant_op.constant(2.0, dtype=dtype)
+          z = math_ops.pow(x, y)
+          error = gradient_checker.compute_gradient_error(y, [], z, [])
+          self.assertLess(error, 2e-4)
+
+  def testAtan2SpecialValues(self):
+    x1l, x2l = zip((+0.0, +0.0), (+0.0, -0.0), (-0.0, +0.0), (-0.0, -0.0),
+                   (1.2345, float("inf")), (1.2345, -float("inf")),
+                   (-4.321, float("inf")), (-4.125, -float("inf")),
+                   (float("inf"), float("inf")), (float("inf"), -float("inf")),
+                   (-float("inf"), float("inf")),
+                   (-float("inf"), -float("inf")))
+    for dtype in np.float32, np.float64:
+      x1 = np.array(x1l).astype(dtype)
+      x2 = np.array(x2l).astype(dtype)
+      self._compareCpu(x1, x2, np.arctan2, math_ops.atan2)
+      self._compareGpu(x1, x2, np.arctan2, math_ops.atan2)
+
+  def testPowNegativeExponent(self):
+    for dtype in [np.int32, np.int64]:
+      with self.test_session(use_gpu=False) as sess:
+        with self.assertRaisesRegexp(
+            errors_impl.InvalidArgumentError,
+            "Integers to negative integer powers are not allowed"):
+          x = np.array([5, 2]).astype(dtype)
+          y = np.array([-2, 3]).astype(dtype)
+          sess.run(math_ops.pow(x, y))
+
+      with self.test_session(use_gpu=False) as sess:
+        with self.assertRaisesRegexp(
+            errors_impl.InvalidArgumentError,
+            "Integers to negative integer powers are not allowed"):
+          x = np.array([5, 2]).astype(dtype)
+          y = np.array([2, -3]).astype(dtype)
+          sess.run(math_ops.pow(x, y))
+
+      with self.test_session(use_gpu=False) as sess:
+        with self.assertRaisesRegexp(
+            errors_impl.InvalidArgumentError,
+            "Integers to negative integer powers are not allowed"):
+          x = np.array([5, 2]).astype(dtype)
+          y = -3
+          sess.run(math_ops.pow(x, y))
+
+
+class ComparisonOpTest(test.TestCase):
+
+  def _compareScalar(self, func, x, y, dtype):
+    with self.test_session(force_gpu=test_util.is_gpu_available()):
+      out = func(
+          ops.convert_to_tensor(np.array([x]).astype(dtype)),
+          ops.convert_to_tensor(np.array([y]).astype(dtype)))
+      ret = out.eval()
+    return ret[0]
+
+  def testScalarCompareScalar(self):
+    dtypes = [np.float16, np.float32, np.float64, np.int32, np.int64]
+    data = [-1, 0, 1]
+    for t in dtypes:
+      for x in data:
+        for y in data:
+          self.assertEqual(self._compareScalar(math_ops.less, x, y, t), x < y)
+          self.assertEqual(
+              self._compareScalar(math_ops.less_equal, x, y, t), x <= y)
+          self.assertEqual(
+              self._compareScalar(math_ops.greater, x, y, t), x > y)
+          self.assertEqual(
+              self._compareScalar(math_ops.greater_equal, x, y, t), x >= y)
+          self.assertEqual(self._compareScalar(math_ops.equal, x, y, t), x == y)
+          self.assertEqual(
+              self._compareScalar(math_ops.not_equal, x, y, t), x != y)
+    data = [-1, 0, 1, -1j, 1j, 1 + 1j, 1 - 1j]
+    for t in [np.complex64, np.complex128]:
+      for x in data:
+        for y in data:
+          self.assertEqual(self._compareScalar(math_ops.equal, x, y, t), x == y)
+          self.assertEqual(
+              self._compareScalar(math_ops.not_equal, x, y, t), x != y)
+
+  def _compare(self, x, y, np_func, tf_func):
+    np_ans = np_func(x, y)
+    with self.test_session(force_gpu=test_util.is_gpu_available()):
+      out = tf_func(ops.convert_to_tensor(x), ops.convert_to_tensor(y))
+      tf_ans = out.eval()
+    self.assertAllEqual(np_ans, tf_ans)
+
+  def testTensorCompareTensor(self):
+    x = np.linspace(-15, 15, 6).reshape(1, 3, 2)
+    y = np.linspace(20, -10, 6).reshape(1, 3, 2)
+    for t in [np.float16, np.float32, np.float64, np.int32, np.int64]:
+      xt = x.astype(t)
+      yt = y.astype(t)
+      self._compare(xt, yt, np.less, math_ops.less)
+      self._compare(xt, yt, np.less_equal, math_ops.less_equal)
+      self._compare(xt, yt, np.greater, math_ops.greater)
+      self._compare(xt, yt, np.greater_equal, math_ops.greater_equal)
+      self._compare(xt, yt, np.equal, math_ops.equal)
+      self._compare(xt, yt, np.not_equal, math_ops.not_equal)
+    # Complex types do not support ordering but do support equality tests.
+    for t in [np.complex64, np.complex128]:
+      xt = x.astype(t)
+      xt -= 1j * xt
+      yt = y.astype(t)
+      yt -= 1j * yt
+      self._compare(xt, yt, np.equal, math_ops.equal)
+      self._compare(xt, yt, np.not_equal, math_ops.not_equal)
+
+  def _compareBCast(self, xs, ys, dtype, np_func, tf_func):
+    x = np.linspace(-15, 15, np.prod(xs)).astype(dtype).reshape(xs)
+    y = np.linspace(20, -10, np.prod(ys)).astype(dtype).reshape(ys)
+    if dtype in (np.complex64, np.complex128):
+      x -= 1j * x
+      y -= 1j * y
+    self._compare(x, y, np_func, tf_func)
+    self._compare(y, x, np_func, tf_func)
+
+  def _testBCastByFunc(self, np_func, tf_func, include_complex=False):
+    shapes = [
+        ([1, 3, 2], [1]),
+        ([1, 3, 2], [2]),
+        ([1, 3, 2], [3, 2]),
+        ([1, 3, 2], [3, 1]),
+        ([1, 3, 2], [1, 3, 2]),
+        ([1, 3, 2], [2, 3, 1]),
+        ([1, 3, 2], [2, 1, 1]),
+        ([1, 3, 2], [1, 3, 1]),
+        ([2, 1, 5], [2, 3, 1]),
+        ([2, 0, 5], [2, 0, 1]),
+        ([2, 3, 0], [2, 3, 1]),
+    ]
+    dtypes = [
+        np.float16,
+        np.float32,
+        np.float64,
+        np.int32,
+        np.int64,
+    ]
+    if include_complex:
+      dtypes.extend([np.complex64, np.complex128])
+
+    for (xs, ys) in shapes:
+      for dtype in dtypes:
+        self._compareBCast(xs, ys, dtype, np_func, tf_func)
+
+  def testBCastLess(self):
+    self._testBCastByFunc(np.less, math_ops.less)
+
+  def testBCastLessEqual(self):
+    self._testBCastByFunc(np.less_equal, math_ops.less_equal)
+
+  def testBCastGreater(self):
+    self._testBCastByFunc(np.greater, math_ops.greater)
+
+  def testBCastGreaterEqual(self):
+    self._testBCastByFunc(np.greater_equal, math_ops.greater_equal)
+
+  def testBCastEqual(self):
+    self._testBCastByFunc(np.equal, math_ops.equal, include_complex=True)
+
+  def testBCastNotEqual(self):
+    self._testBCastByFunc(
+        np.not_equal, math_ops.not_equal, include_complex=True)
+
+  def testShapeMismatch(self):
+    dtypes = [np.float16, np.float32, np.float64, np.int32, np.int64]
+    funcs = [
+        math_ops.less, math_ops.less_equal, math_ops.greater,
+        math_ops.greater_equal, math_ops.equal, math_ops.not_equal
+    ]
+    x = np.arange(0, 10).reshape([2, 5])
+    y = np.arange(0, 10).reshape([5, 2])
+    for t in dtypes:
+      for f in funcs:
+        with self.assertRaisesWithPredicateMatch(
+            ValueError, lambda e: "Dimensions must" in str(e)):
+          f(x.astype(t), y.astype(t))
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py
index 00d7f95..c5311ad 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -18,25 +18,19 @@
 from __future__ import division
 from __future__ import print_function
 
-import math
-
 import numpy as np
 
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes as dtypes_lib
-from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_math_ops
 from tensorflow.python.ops import gradient_checker
-from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
-from tensorflow.python.platform import tf_logging
 
 _ADD = lambda x, y: x + y
 _SUB = lambda x, y: x - y
@@ -45,8 +39,6 @@
 _TRUEDIV = lambda x, y: x / y
 _FLOORDIV = lambda x, y: x // y
 _MOD = lambda x, y: x % y
-_NEG = lambda x: -x
-_ABS = abs
 
 _LT = lambda x, y: x < y
 _LE = lambda x, y: x <= y
@@ -74,8 +66,11 @@
 
 
 def _default_tolerance(dtype):
-  """Returns a sensible default tolerance for comparing results of a given
-  type"""
+  """Returns a sensible default tolerance for comparing results of a given type.
+
+  Args:
+    dtype: A NumPy dtype.
+  """
   if dtype == np.float16:
     return 5e-3
   elif dtype in (np.float32, np.complex64):
@@ -86,1147 +81,6 @@
     return None  # Fail fast for unexpected types
 
 
-class UnaryOpTest(test.TestCase):
-
-  def _compareCpu(self, x, np_func, tf_func, grad_rtol=None, grad_atol=None):
-    if grad_rtol is None:
-      grad_rtol = _default_tolerance(x.dtype)
-    if grad_atol is None:
-      grad_atol = _default_tolerance(x.dtype)
-    np_ans = np_func(x)
-    with self.test_session(use_gpu=False):
-      inx = ops.convert_to_tensor(x)
-      if x.dtype in (np.float32, np.float64,
-                     dtypes_lib.bfloat16.as_numpy_dtype):
-        y = 1.1 * tf_func(inx)
-        np_ans *= 1.1
-      else:
-        y = tf_func(inx)
-      tf_cpu = y.eval()
-      self.assertShapeEqual(np_ans, y)
-      if x.dtype == np.float16:
-        self.assertAllClose(np_ans, tf_cpu, rtol=1e-3, atol=1e-3)
-      elif x.dtype == dtypes_lib.bfloat16.as_numpy_dtype:
-        self.assertAllClose(np_ans, tf_cpu, rtol=1e-2, atol=1e-2)
-      else:
-        self.assertAllClose(np_ans, tf_cpu)
-
-      if x.dtype in (np.complex64, np.complex128) and tf_func == math_ops.sign:
-        return  # Return early
-
-      if x.dtype == np.float16:
-        s = list(np.shape(x))
-        jacob_t, _ = gradient_checker.compute_gradient(
-            inx, s, y, s, x_init_value=x)
-        xf = x.astype(np.float)
-        inxf = ops.convert_to_tensor(xf)
-        yf = tf_func(inxf)
-        _, jacob_n = gradient_checker.compute_gradient(
-            inxf, s, yf, s, x_init_value=xf, delta=1e-2)
-        jacob_n = jacob_n.astype(np.float16)
-        self.assertAllClose(jacob_t, jacob_n, rtol=grad_rtol, atol=grad_atol)
-      elif x.dtype in (np.float32, np.complex64):
-        s = list(np.shape(x))
-        jacob_t, jacob_n = gradient_checker.compute_gradient(
-            inx, s, y, s, x_init_value=x, delta=1e-3)
-        self.assertAllClose(jacob_t, jacob_n, rtol=grad_rtol, atol=grad_atol)
-      elif x.dtype in (np.float64, np.complex128):
-        s = list(np.shape(x))
-        jacob_t, jacob_n = gradient_checker.compute_gradient(
-            inx, s, y, s, x_init_value=x, delta=1e-5)
-        self.assertAllClose(jacob_t, jacob_n, rtol=grad_rtol, atol=grad_atol)
-
-  def _check(self, result_tensor, result_np, input_sp_t, tol):
-    self.assertTrue(isinstance(result_tensor, sparse_tensor.SparseTensor))
-    self.assertTrue(isinstance(input_sp_t, sparse_tensor.SparseTensor))
-    self.assertAllEqual(input_sp_t.indices.eval(), result_tensor.indices.eval())
-    self.assertAllEqual(input_sp_t.dense_shape.eval(),
-                        result_tensor.dense_shape.eval())
-    if tol is None:
-      self.assertAllClose(result_np, result_tensor.values.eval())
-    else:
-      self.assertAllClose(
-          result_np, result_tensor.values.eval(), rtol=tol, atol=tol)
-
-  def _compareSparseCpu(self, x, np_func, tf_func, tol):
-    x_sp, x_sp_vals = _sparsify(x)
-    res_np = np_func(x_sp_vals)
-    with self.test_session(use_gpu=False):
-      self._check(tf_func(x_sp), res_np, x_sp, tol)
-
-  def _compareGpu(self, x, np_func, tf_func):
-    np_ans = np_func(x)
-    with self.test_session(force_gpu=test_util.is_gpu_available()):
-      result = tf_func(ops.convert_to_tensor(x))
-      tf_gpu = result.eval()
-    if x.dtype == np.float16:
-      self.assertAllClose(np_ans, tf_gpu, rtol=1e-3, atol=1e-3)
-    else:
-      self.assertAllClose(np_ans, tf_gpu)
-    # TODO(zhifengc/ke): make gradient checker work on GPU.
-
-  def _compareSparseGpu(self, x, np_func, tf_func, tol):
-    x_sp, x_sp_vals = _sparsify(x)
-    res_np = np_func(x_sp_vals)
-    with self.test_session(force_gpu=test_util.is_gpu_available()):
-      self._check(tf_func(x_sp), res_np, x_sp, tol)
-
-  def _compareBoth(self, x, np_func, tf_func):
-    self._compareCpu(x, np_func, tf_func)
-    self._compareGpu(x, np_func, tf_func)
-
-  def _compareBothSparse(self, x, np_func, tf_func, tol=None):
-    self._compareSparseCpu(x, np_func, tf_func, tol)
-    self._compareSparseGpu(x, np_func, tf_func, tol)
-
-  def _inv(self, x):
-    return 1.0 / x
-
-  def _rsqrt(self, x):
-    return self._inv(np.sqrt(x))
-
-  def _sigmoid(self, x):
-    return 1.0 / (1.0 + np.exp(-x))
-
-  def _log_sigmoid(self, x):
-    return np.log(self._sigmoid(x))
-
-  def _replace_domain_error_with_inf(self, fn):
-
-    def func(x):
-      try:
-        return fn(x)
-      except ValueError as e:
-        if "domain error" in str(e):
-          return np.inf * np.ones_like(x)
-        else:
-          raise e
-
-    return func
-
-  def testFloatBasic(self):
-    x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float32)
-    w = x - x.min() + 1.02  # all greater than 1
-    y = (x + .5).astype(np.float32)  # no zero
-    z = (x + 15.5).astype(np.float32)  # all positive
-    k = np.arange(-0.90, 0.90, 0.25).astype(np.float32)  # between -1 and 1
-
-    self._compareBoth(x, np.abs, math_ops.abs)
-    self._compareBoth(x, np.abs, _ABS)
-    self._compareBoth(x, np.negative, math_ops.negative)
-    self._compareBoth(x, np.negative, _NEG)
-    self._compareBoth(y, self._inv, math_ops.reciprocal)
-    self._compareBoth(x, np.square, math_ops.square)
-    self._compareBoth(z, np.sqrt, math_ops.sqrt)
-    self._compareBoth(z, self._rsqrt, math_ops.rsqrt)
-    self._compareBoth(x, np.exp, math_ops.exp)
-    self._compareBoth(x, np.expm1, math_ops.expm1)
-    self._compareBoth(z, np.log, math_ops.log)
-    self._compareBoth(z, np.log1p, math_ops.log1p)
-    self._compareBoth(x, np.sinh, math_ops.sinh)
-    self._compareBoth(x, np.cosh, math_ops.cosh)
-    self._compareBoth(x, np.tanh, math_ops.tanh)
-    self._compareBoth(x, np.arcsinh, math_ops.asinh)
-    self._compareBoth(w, np.arccosh, math_ops.acosh)
-    self._compareBoth(k, np.arctanh, math_ops.atanh)
-    self._compareBoth(x, self._sigmoid, math_ops.sigmoid)
-    self._compareBoth(x, self._log_sigmoid, math_ops.log_sigmoid)
-    self._compareBoth(y, np.sign, math_ops.sign)
-    self._compareBoth(x, np.sin, math_ops.sin)
-    self._compareBoth(x, np.cos, math_ops.cos)
-    self._compareBoth(k, np.arcsin, math_ops.asin)
-    self._compareBoth(k, np.arccos, math_ops.acos)
-    self._compareBoth(x, np.arctan, math_ops.atan)
-    self._compareBoth(x, np.tan, math_ops.tan)
-    self._compareBoth(y,
-                      np.vectorize(
-                          self._replace_domain_error_with_inf(math.lgamma)),
-                      math_ops.lgamma)
-    self._compareBoth(x, np.vectorize(math.erf), math_ops.erf)
-    self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc)
-    try:
-      from scipy import special  # pylint: disable=g-import-not-at-top
-      self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
-      self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
-    except ImportError as e:
-      tf_logging.warn("Cannot test special functions: %s" % str(e))
-
-    self._compareBothSparse(x, np.abs, math_ops.abs)
-    self._compareBothSparse(x, np.negative, math_ops.negative)
-    self._compareBothSparse(x, np.square, math_ops.square)
-    self._compareBothSparse(z, np.sqrt, math_ops.sqrt, tol=1e-3)
-    self._compareBothSparse(x, np.tanh, math_ops.tanh)
-    self._compareBothSparse(y, np.sign, math_ops.sign)
-    self._compareBothSparse(x, np.vectorize(math.erf), math_ops.erf)
-
-  def testFloatTanhEdge(self):
-    x = np.arange(40, 40 + 6).reshape(6).astype(np.float32)
-    self._compareBoth(x, np.tanh, math_ops.tanh)
-    x = np.arange(-40, -40 + 6).reshape(6).astype(np.float32)
-    self._compareBoth(x, np.tanh, math_ops.tanh)
-
-  def testFloatEmpty(self):
-    x = np.empty((2, 0, 5), dtype=np.float32)
-    self._compareBoth(x, np.abs, math_ops.abs)
-    self._compareBoth(x, np.abs, _ABS)
-    self._compareBoth(x, np.negative, math_ops.negative)
-    self._compareBoth(x, np.negative, _NEG)
-    self._compareBoth(x, self._inv, math_ops.reciprocal)
-    self._compareBoth(x, np.square, math_ops.square)
-    self._compareBoth(x, np.sqrt, math_ops.sqrt)
-    self._compareBoth(x, self._rsqrt, math_ops.rsqrt)
-    self._compareBoth(x, np.exp, math_ops.exp)
-    self._compareBoth(x, np.expm1, math_ops.expm1)
-    self._compareBoth(x, np.log, math_ops.log)
-    self._compareBoth(x, np.log1p, math_ops.log1p)
-    self._compareBoth(x, np.sinh, math_ops.sinh)
-    self._compareBoth(x, np.arcsinh, math_ops.asinh)
-    self._compareBoth(x, np.cosh, math_ops.cosh)
-    self._compareBoth(x, np.tanh, math_ops.tanh)
-    self._compareBoth(x, self._sigmoid, math_ops.sigmoid)
-    self._compareBoth(x, np.sign, math_ops.sign)
-    self._compareBoth(x, np.sin, math_ops.sin)
-    self._compareBoth(x, np.cos, math_ops.cos)
-    # Can't use vectorize below, so just use some arbitrary function
-    self._compareBoth(x, np.sign, math_ops.lgamma)
-    self._compareBoth(x, np.sign, math_ops.erf)
-    self._compareBoth(x, np.sign, math_ops.erfc)
-    self._compareBoth(x, np.tan, math_ops.tan)
-    self._compareBoth(x, np.arcsin, math_ops.asin)
-    self._compareBoth(x, np.arccos, math_ops.acos)
-    self._compareBoth(x, np.arctan, math_ops.atan)
-    try:
-      from scipy import special  # pylint: disable=g-import-not-at-top
-      self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
-      self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
-    except ImportError as e:
-      tf_logging.warn("Cannot test special functions: %s" % str(e))
-
-    self._compareBothSparse(x, np.abs, math_ops.abs)
-    self._compareBothSparse(x, np.negative, math_ops.negative)
-    self._compareBothSparse(x, np.square, math_ops.square)
-    self._compareBothSparse(x, np.sqrt, math_ops.sqrt, tol=1e-3)
-    self._compareBothSparse(x, np.tanh, math_ops.tanh)
-    self._compareBothSparse(x, np.sign, math_ops.sign)
-    self._compareBothSparse(x, np.sign, math_ops.erf)
-
-  def testDoubleBasic(self):
-    x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float64)
-    w = x - x.min() + 1.02  # all greater than 1
-    y = (x + .5).astype(np.float64)  # no zero
-    z = (x + 15.5).astype(np.float64)  # all positive
-    k = np.arange(-0.90, 0.90,
-                  0.35).reshape(1, 3, 2).astype(np.float64)  # between -1 and 1
-    self._compareBoth(x, np.abs, math_ops.abs)
-    self._compareBoth(x, np.abs, _ABS)
-    self._compareBoth(x, np.negative, math_ops.negative)
-    self._compareBoth(x, np.negative, _NEG)
-    self._compareBoth(y, self._inv, math_ops.reciprocal)
-    self._compareBoth(x, np.square, math_ops.square)
-    self._compareBoth(z, np.sqrt, math_ops.sqrt)
-    self._compareBoth(z, self._rsqrt, math_ops.rsqrt)
-    self._compareBoth(x, np.exp, math_ops.exp)
-    self._compareBoth(x, np.expm1, math_ops.expm1)
-    self._compareBoth(z, np.log, math_ops.log)
-    self._compareBoth(z, np.log1p, math_ops.log1p)
-    self._compareBoth(x, np.sinh, math_ops.sinh)
-    self._compareBoth(x, np.cosh, math_ops.cosh)
-    self._compareBoth(x, np.tanh, math_ops.tanh)
-    self._compareBoth(x, np.arcsinh, math_ops.asinh)
-    self._compareBoth(w, np.arccosh, math_ops.acosh)
-    self._compareBoth(k, np.arctanh, math_ops.atanh)
-    self._compareBoth(x, self._sigmoid, math_ops.sigmoid)
-    self._compareBoth(y, np.sign, math_ops.sign)
-    self._compareBoth(x, np.sin, math_ops.sin)
-    self._compareBoth(x, np.cos, math_ops.cos)
-    self._compareBoth(y,
-                      np.vectorize(
-                          self._replace_domain_error_with_inf(math.lgamma)),
-                      math_ops.lgamma)
-    self._compareBoth(x, np.vectorize(math.erf), math_ops.erf)
-    self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc)
-    self._compareBoth(x, np.arctan, math_ops.atan)
-    self._compareBoth(k, np.arcsin, math_ops.asin)
-    self._compareBoth(k, np.arccos, math_ops.acos)
-    self._compareBoth(k, np.tan, math_ops.tan)
-    try:
-      from scipy import special  # pylint: disable=g-import-not-at-top
-      self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
-      self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
-    except ImportError as e:
-      tf_logging.warn("Cannot test special functions: %s" % str(e))
-
-    self._compareBothSparse(x, np.abs, math_ops.abs)
-    self._compareBothSparse(x, np.negative, math_ops.negative)
-    self._compareBothSparse(x, np.square, math_ops.square)
-    self._compareBothSparse(z, np.sqrt, math_ops.sqrt, tol=1e-3)
-    self._compareBothSparse(x, np.tanh, math_ops.tanh)
-    self._compareBothSparse(y, np.sign, math_ops.sign)
-    self._compareBothSparse(x, np.vectorize(math.erf), math_ops.erf)
-
-  def testHalfBasic(self):
-    x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float16)
-    y = (x + .5).astype(np.float16)  # no zero
-    z = (x + 15.5).astype(np.float16)  # all positive
-    self._compareBoth(x, np.abs, math_ops.abs)
-    self._compareBoth(x, np.abs, _ABS)
-    self._compareBoth(x, np.negative, math_ops.negative)
-    self._compareBoth(x, np.negative, _NEG)
-    self._compareBoth(y, self._inv, math_ops.reciprocal)
-    self._compareBoth(x, np.square, math_ops.square)
-    self._compareBoth(z, np.sqrt, math_ops.sqrt)
-    self._compareBoth(z, self._rsqrt, math_ops.rsqrt)
-    self._compareBoth(x, np.exp, math_ops.exp)
-    self._compareBoth(x, np.expm1, math_ops.expm1)
-    self._compareBoth(z, np.log, math_ops.log)
-    self._compareBoth(z, np.log1p, math_ops.log1p)
-    self._compareBoth(x, np.tanh, math_ops.tanh)
-    self._compareBoth(x, self._sigmoid, math_ops.sigmoid)
-    self._compareBoth(y, np.sign, math_ops.sign)
-    self._compareBoth(x, np.sin, math_ops.sin)
-    self._compareBoth(x, np.cos, math_ops.cos)
-    self._compareBoth(y,
-                      np.vectorize(
-                          self._replace_domain_error_with_inf(math.lgamma)),
-                      math_ops.lgamma)
-    self._compareBoth(x, np.vectorize(math.erf), math_ops.erf)
-    self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc)
-    try:
-      from scipy import special  # pylint: disable=g-import-not-at-top
-      self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
-      self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
-    except ImportError as e:
-      tf_logging.warn("Cannot test special functions: %s" % str(e))
-
-    self._compareBothSparse(x, np.abs, math_ops.abs)
-    self._compareBothSparse(x, np.negative, math_ops.negative)
-    self._compareBothSparse(x, np.square, math_ops.square)
-    self._compareBothSparse(z, np.sqrt, math_ops.sqrt, tol=1e-3)
-    self._compareBothSparse(x, np.tanh, math_ops.tanh)
-    self._compareBothSparse(y, np.sign, math_ops.sign)
-    self._compareBothSparse(x, np.vectorize(math.erf), math_ops.erf, tol=1e-3)
-
-  def testInt32Basic(self):
-    x = np.arange(-6, 6, 2).reshape(1, 3, 2).astype(np.int32)
-    self._compareCpu(x, np.abs, math_ops.abs)
-    self._compareCpu(x, np.abs, _ABS)
-    self._compareBoth(x, np.negative, math_ops.negative)
-    self._compareBoth(x, np.negative, _NEG)
-    self._compareBoth(x, np.square, math_ops.square)
-    self._compareCpu(x, np.sign, math_ops.sign)
-
-    self._compareBothSparse(x, np.abs, math_ops.abs)
-    self._compareBothSparse(x, np.negative, math_ops.negative)
-    self._compareBothSparse(x, np.square, math_ops.square)
-    self._compareBothSparse(x, np.sign, math_ops.sign)
-
-  def testInt64Basic(self):
-    x = np.arange(-6 << 40, 6 << 40, 2 << 40).reshape(1, 3, 2).astype(np.int64)
-    self._compareCpu(x, np.abs, math_ops.abs)
-    self._compareCpu(x, np.abs, _ABS)
-    self._compareCpu(x, np.negative, math_ops.negative)
-    self._compareCpu(x, np.negative, _NEG)
-    self._compareCpu(x, np.sign, math_ops.sign)
-
-    self._compareBothSparse(x, np.abs, math_ops.abs)
-    self._compareBothSparse(x, np.negative, math_ops.negative)
-    self._compareBothSparse(x, np.sign, math_ops.sign)
-
-  def testInt64Square(self):
-    x = np.arange(-6 << 20, 6 << 20, 2 << 20).reshape(1, 3, 2).astype(np.int64)
-    self._compareCpu(x, np.square, math_ops.square)
-    self._compareBothSparse(x, np.square, math_ops.square)
-
-  def testComplex64Basic(self):
-    x = np.complex(1, 1) * np.arange(-3, 3).reshape(1, 3, 2).astype(
-        np.complex64)
-    y = x + np.complex(0.5, 0.5)  # no zeros
-    self._compareBoth(x, np.abs, math_ops.abs)
-    self._compareBoth(x, np.abs, _ABS)
-    self._compareBoth(x, np.negative, math_ops.negative)
-    self._compareBoth(x, np.negative, _NEG)
-    self._compareCpu(y, self._inv, math_ops.reciprocal)
-    self._compareCpu(x, np.square, math_ops.square)
-    self._compareCpu(y, np.sqrt, math_ops.sqrt)
-    self._compareCpu(y, self._rsqrt, math_ops.rsqrt)
-    self._compareBoth(x, np.exp, math_ops.exp)
-    self._compareCpu(x, np.expm1, math_ops.expm1)
-    self._compareCpu(y, np.log, math_ops.log)
-    self._compareCpu(y, np.log1p, math_ops.log1p)
-    self._compareCpu(x, np.sinh, math_ops.sinh)
-    self._compareCpu(x, np.cosh, math_ops.cosh)
-    self._compareCpu(x, np.tanh, math_ops.tanh)
-
-    # Complex64 versions of asinh() and acosh() in libstdc++ only have 6 digits
-    # of precision.
-    # Small gradient values + low precision --> High relative error
-    self._compareCpu(y, np.arcsinh, math_ops.asinh, grad_rtol=1e-2)
-    self._compareCpu(y, np.arccosh, math_ops.acosh, grad_rtol=1e-2)
-
-    self._compareCpu(y, np.arctanh, math_ops.atanh)
-    self._compareCpu(x, self._sigmoid, math_ops.sigmoid)
-    self._compareCpu(x, np.sin, math_ops.sin)
-    self._compareCpu(x, np.cos, math_ops.cos)
-
-    self._compareBothSparse(x, np.abs, math_ops.abs)
-    self._compareBothSparse(x, np.negative, math_ops.negative)
-    self._compareBothSparse(x, np.square, math_ops.square)
-    self._compareBothSparse(x, np.sqrt, math_ops.sqrt, 1e-3)
-    self._compareBothSparse(x, np.tanh, math_ops.tanh)
-
-    # Numpy uses an incorrect definition of sign; use the right one instead.
-    def complex_sign(x):
-      return x / np.abs(x)
-
-    self._compareBoth(y, complex_sign, math_ops.sign)
-    self._compareBothSparse(y, complex_sign, math_ops.sign)
-
-  def testComplex128Basic(self):
-    x = np.complex(1, 1) * np.arange(-3, 3).reshape(1, 3, 2).astype(
-        np.complex128)
-    y = x + np.complex(0.5, 0.5)  # no zeros
-    self._compareBoth(x, np.abs, math_ops.abs)
-    self._compareBoth(x, np.abs, _ABS)
-    self._compareBoth(x, np.negative, math_ops.negative)
-    self._compareBoth(x, np.negative, _NEG)
-    self._compareCpu(y, self._inv, math_ops.reciprocal)
-    self._compareCpu(x, np.square, math_ops.square)
-    self._compareCpu(y, np.sqrt, math_ops.sqrt)
-    self._compareCpu(y, self._rsqrt, math_ops.rsqrt)
-    self._compareBoth(x, np.exp, math_ops.exp)
-    self._compareCpu(x, np.expm1, math_ops.expm1)
-    self._compareCpu(y, np.log, math_ops.log)
-    self._compareCpu(y, np.log1p, math_ops.log1p)
-    self._compareCpu(x, np.sinh, math_ops.sinh)
-    self._compareCpu(x, np.cosh, math_ops.cosh)
-    self._compareCpu(x, np.tanh, math_ops.tanh)
-    self._compareCpu(y, np.arcsinh, math_ops.asinh)
-    self._compareCpu(y, np.arccosh, math_ops.acosh)
-    self._compareCpu(y, np.arctanh, math_ops.atanh)
-    self._compareCpu(x, self._sigmoid, math_ops.sigmoid)
-    self._compareCpu(x, np.sin, math_ops.sin)
-    self._compareCpu(x, np.cos, math_ops.cos)
-
-    self._compareBothSparse(x, np.abs, math_ops.abs)
-    self._compareBothSparse(x, np.negative, math_ops.negative)
-    self._compareBothSparse(x, np.square, math_ops.square)
-    self._compareBothSparse(x, np.sqrt, math_ops.sqrt, 1e-3)
-    self._compareBothSparse(x, np.tanh, math_ops.tanh)
-
-    # Numpy uses an incorrect definition of sign; use the right one instead.
-    def complex_sign(x):
-      return x / np.abs(x)
-
-    self._compareBoth(y, complex_sign, math_ops.sign)
-    self._compareBothSparse(y, complex_sign, math_ops.sign)
-
-  def testGradGrad(self):
-    np.random.seed(7)
-    shape = (5,)
-    dtype_tols = [(np.float32, 5e-4), (np.float64, 1e-6), (np.complex64, 5e-4),
-                  (np.complex128, 1e-6)]
-    op_range = [
-        (gen_math_ops.reciprocal_grad, [-2, 2]),
-        (gen_math_ops.rsqrt_grad, [0.1, 3]),
-        (gen_math_ops.sigmoid_grad, [-2, 2]),
-        (gen_math_ops.sqrt_grad, [0.1, 3]),
-        (gen_math_ops.tanh_grad, [-2, 2]),
-    ]
-
-    def rand(dtype):
-      x = np.random.uniform(
-          real_range[0], real_range[1], size=shape[0]).astype(dtype)
-      if dtype in (np.complex64, np.complex128):
-        x += 1j * np.random.uniform(-2, 2, size=shape[0]).astype(dtype)
-      return x
-
-    for op, real_range in op_range:
-      with self.cached_session():
-        for dtype, tol in dtype_tols:
-          x = constant_op.constant(rand(dtype))
-          y = constant_op.constant(rand(dtype))
-          z = op(x, y)
-          grads = gradient_checker.compute_gradient(
-              [x, y], [shape, shape],
-              z,
-              shape,
-              x_init_value=[rand(dtype), rand(dtype)])
-          if isinstance(grads, tuple):
-            grads = [grads]
-          for analytical, numerical in grads:
-            self.assertAllClose(analytical, numerical, rtol=tol, atol=tol)
-
-
-class BinaryOpTest(test.TestCase):
-
-  def _compareCpu(self, x, y, np_func, tf_func, also_compare_variables=False):
-    np_ans = np_func(x, y)
-    with self.test_session(use_gpu=False):
-      inx = ops.convert_to_tensor(x)
-      iny = ops.convert_to_tensor(y)
-      out = tf_func(inx, iny)
-      tf_cpu = out.eval()
-      # Test that the op takes precedence over numpy operators.
-      np_left = tf_func(x, iny).eval()
-      np_right = tf_func(inx, y).eval()
-
-      if also_compare_variables:
-        var_x = variables.Variable(x)
-        var_y = variables.Variable(y)
-        variables.global_variables_initializer().run()
-        print(type(x), type(y), type(var_x), type(var_y))
-        print(type(tf_func(x, var_y)), type(tf_func(var_x, y)))
-        np_var_left = tf_func(x, var_y).eval()
-        np_var_right = tf_func(var_x, y).eval()
-
-    if np_ans.dtype != np.object:
-      self.assertAllClose(np_ans, tf_cpu)
-      self.assertAllClose(np_ans, np_left)
-      self.assertAllClose(np_ans, np_right)
-      if also_compare_variables:
-        self.assertAllClose(np_ans, np_var_left)
-        self.assertAllClose(np_ans, np_var_right)
-    self.assertShapeEqual(np_ans, out)
-
-  _GRAD_TOL = {
-      dtypes_lib.float16: 1e-3,
-      dtypes_lib.float32: 1e-3,
-      dtypes_lib.complex64: 1e-2,
-      dtypes_lib.float64: 1e-5,
-      dtypes_lib.complex128: 1e-4
-  }
-
-  def _compareGradientX(self,
-                        x,
-                        y,
-                        np_func,
-                        tf_func,
-                        numeric_gradient_type=None):
-    z = np_func(x, y)
-    zs = list(z.shape)
-    with self.cached_session():
-      inx = ops.convert_to_tensor(x)
-      iny = ops.convert_to_tensor(y)
-      if x.dtype in (np.float32, np.float64):
-        out = 1.1 * tf_func(inx, iny)
-      else:
-        out = tf_func(inx, iny)
-      xs = list(x.shape)
-      jacob_t, jacob_n = gradient_checker.compute_gradient(
-          inx, xs, out, zs, x_init_value=x)
-      if numeric_gradient_type is not None:
-        xf = x.astype(numeric_gradient_type)
-        yf = y.astype(numeric_gradient_type)
-        inxf = ops.convert_to_tensor(xf)
-        inyf = ops.convert_to_tensor(yf)
-        outf = tf_func(inxf, inyf)
-        _, jacob_n = gradient_checker.compute_gradient(
-            inxf, xs, outf, zs, x_init_value=xf, delta=1e-3)
-        jacob_n = jacob_n.astype(x.dtype)
-      tol = self._GRAD_TOL[dtypes_lib.as_dtype(x.dtype)]
-      self.assertAllClose(jacob_t, jacob_n, rtol=tol, atol=tol)
-
-  def _compareGradientY(self,
-                        x,
-                        y,
-                        np_func,
-                        tf_func,
-                        numeric_gradient_type=None):
-    z = np_func(x, y)
-    zs = list(z.shape)
-    with self.cached_session():
-      inx = ops.convert_to_tensor(x)
-      iny = ops.convert_to_tensor(y)
-      if x.dtype in (np.float32, np.float64):
-        out = 1.1 * tf_func(inx, iny)
-      else:
-        out = tf_func(inx, iny)
-      ys = list(np.shape(y))
-      jacob_t, jacob_n = gradient_checker.compute_gradient(
-          iny, ys, out, zs, x_init_value=y)
-      if numeric_gradient_type is not None:
-        xf = x.astype(numeric_gradient_type)
-        yf = y.astype(numeric_gradient_type)
-        inxf = ops.convert_to_tensor(xf)
-        inyf = ops.convert_to_tensor(yf)
-        outf = tf_func(inxf, inyf)
-        _, jacob_n = gradient_checker.compute_gradient(
-            inyf, ys, outf, zs, x_init_value=yf)
-        jacob_n = jacob_n.astype(x.dtype)
-    tol = self._GRAD_TOL[dtypes_lib.as_dtype(x.dtype)]
-    self.assertAllClose(jacob_t, jacob_n, rtol=tol, atol=tol)
-
-  def _compareGpu(self, x, y, np_func, tf_func):
-    np_ans = np_func(x, y)
-    with self.test_session(force_gpu=test_util.is_gpu_available()):
-      inx = ops.convert_to_tensor(x)
-      iny = ops.convert_to_tensor(y)
-      out = tf_func(inx, iny)
-      tf_gpu = out.eval()
-    self.assertAllClose(np_ans, tf_gpu)
-    self.assertShapeEqual(np_ans, out)
-    # TODO(zhifengc/ke): make gradient checker work on GPU.
-
-  def _compareBoth(self, x, y, np_func, tf_func, also_compare_variables=False):
-    self._compareCpu(x, y, np_func, tf_func, also_compare_variables)
-    if x.dtype in (np.float16, np.float32, np.float64, np.complex64,
-                   np.complex128):
-      if tf_func not in (_FLOORDIV, math_ops.floordiv, math_ops.zeta,
-                         math_ops.polygamma):
-        self._compareGradientX(x, y, np_func, tf_func)
-        self._compareGradientY(x, y, np_func, tf_func)
-      if tf_func in (math_ops.zeta, math_ops.polygamma):
-        # These methods only support gradients in the second parameter
-        self._compareGradientY(x, y, np_func, tf_func)
-      self._compareGpu(x, y, np_func, tf_func)
-
-  def testFloatBasic(self):
-    x = np.linspace(-5, 20, 15).reshape(1, 3, 5).astype(np.float32)
-    y = np.linspace(20, -5, 15).reshape(1, 3, 5).astype(np.float32)
-    self._compareBoth(x, y, np.add, math_ops.add, also_compare_variables=True)
-    self._compareBoth(x, y, np.subtract, math_ops.subtract)
-    self._compareBoth(x, y, np.multiply, math_ops.multiply)
-    self._compareBoth(x, y + 0.1, np.true_divide, math_ops.truediv)
-    self._compareBoth(x, y + 0.1, np.floor_divide, math_ops.floordiv)
-    self._compareBoth(x, y, np.add, _ADD)
-    self._compareBoth(x, y, np.subtract, _SUB)
-    self._compareBoth(x, y, np.multiply, _MUL)
-    self._compareBoth(x, y + 0.1, np.true_divide, _TRUEDIV)
-    self._compareBoth(x, y + 0.1, np.floor_divide, _FLOORDIV)
-    self._compareBoth(x, y, np.arctan2, math_ops.atan2)
-    x1 = np.random.randn(5, 6).astype(np.float32)
-    x2 = np.random.randn(5, 6).astype(np.float32)
-    # Remove tiny values--atan2 gradients are flaky near the origin.
-    x1[np.abs(x1) < 0.05] = 0.05 * np.sign(x1[np.abs(x1) < 0.05])
-    x2[np.abs(x2) < 0.05] = 0.05 * np.sign(x2[np.abs(x2) < 0.05])
-    self._compareBoth(x1, x2, np.arctan2, math_ops.atan2)
-    try:
-      from scipy import special  # pylint: disable=g-import-not-at-top
-      a_pos_small = np.linspace(0.1, 2, 15).reshape(1, 3, 5).astype(np.float32)
-      x_pos_small = np.linspace(0.1, 10, 15).reshape(1, 3, 5).astype(np.float32)
-      self._compareBoth(a_pos_small, x_pos_small, special.gammainc,
-                        math_ops.igamma)
-      self._compareBoth(a_pos_small, x_pos_small, special.gammaincc,
-                        math_ops.igammac)
-      # Need x > 1
-      self._compareBoth(x_pos_small + 1, a_pos_small, special.zeta,
-                        math_ops.zeta)
-      n_small = np.arange(0, 15).reshape(1, 3, 5).astype(np.float32)
-      self._compareBoth(n_small, x_pos_small, special.polygamma,
-                        math_ops.polygamma)
-    except ImportError as e:
-      tf_logging.warn("Cannot test special functions: %s" % str(e))
-
-  def testFloatDifferentShapes(self):
-    x = np.array([1, 2, 3, 4]).reshape(2, 2).astype(np.float32)
-    y = np.array([1, 2]).reshape(2, 1).astype(np.float32)
-    with self.cached_session() as sess:
-      inx = ops.convert_to_tensor(x)
-      iny = ops.convert_to_tensor(y)
-      s = math_ops.reduce_sum(inx * iny)
-      gx, gy = sess.run(gradients_impl.gradients(s, [inx, iny]))
-    # gx is simply the broadcasted y
-    self.assertAllEqual(gx,
-                        np.array([1, 1, 2, 2]).reshape(2, 2).astype(np.float32))
-    # gy is x's column summed up
-    self.assertAllEqual(gy, np.array([3, 7]).reshape(2, 1).astype(np.float32))
-
-  def testFloatVariableOverload(self):
-    x = np.array([1, 2, 3, 4]).reshape(2, 2).astype(np.int32)
-    y = np.array([1, 2]).reshape(2, 1).astype(np.int32)
-    var_x = variables.Variable(x)
-    var_y = variables.Variable(y)
-    with self.cached_session() as sess:
-      sess.run([var_x.initializer, var_y.initializer])
-      left_result = (var_x * y).eval()
-      right_result = (x * var_y).eval()
-    np_result = x * y
-    self.assertAllEqual(np_result, left_result)
-    self.assertAllEqual(np_result, right_result)
-
-  def testDoubleBasic(self):
-    x = np.linspace(-5, 20, 15).reshape(1, 3, 5).astype(np.float64)
-    y = np.linspace(20, -5, 15).reshape(1, 3, 5).astype(np.float64)
-    self._compareBoth(x, y, np.add, math_ops.add)
-    self._compareBoth(x, y, np.subtract, math_ops.subtract)
-    self._compareBoth(x, y, np.multiply, math_ops.multiply)
-    self._compareBoth(x, y + 0.1, np.true_divide, math_ops.truediv)
-    self._compareBoth(x, y + 0.1, np.floor_divide, math_ops.floordiv)
-    self._compareBoth(x, y, np.add, _ADD)
-    self._compareBoth(x, y, np.subtract, _SUB)
-    self._compareBoth(x, y, np.multiply, _MUL)
-    self._compareBoth(x, y + 0.1, np.true_divide, _TRUEDIV)
-    self._compareBoth(x, y + 0.1, np.floor_divide, _FLOORDIV)
-    self._compareBoth(x, y, np.arctan2, math_ops.atan2)
-    x1 = np.random.randn(7, 4).astype(np.float64)
-    x2 = np.random.randn(7, 4).astype(np.float64)
-    # Remove tiny values--atan2 gradients are flaky near the origin.
-    x1[np.abs(x1) < 0.5] = 0.5 * np.sign(x1[np.abs(x1) < 0.5])
-    x2[np.abs(x2) < 0.5] = 0.5 * np.sign(x2[np.abs(x2) < 0.5])
-    self._compareBoth(x1, x2, np.arctan2, math_ops.atan2)
-    try:
-      from scipy import special  # pylint: disable=g-import-not-at-top
-      a_pos_small = np.linspace(0.1, 2, 15).reshape(1, 3, 5).astype(np.float32)
-      x_pos_small = np.linspace(0.1, 10, 15).reshape(1, 3, 5).astype(np.float32)
-      self._compareBoth(a_pos_small, x_pos_small, special.gammainc,
-                        math_ops.igamma)
-      self._compareBoth(a_pos_small, x_pos_small, special.gammaincc,
-                        math_ops.igammac)
-    except ImportError as e:
-      tf_logging.warn("Cannot test special functions: %s" % str(e))
-
-  def testUint8Basic(self):
-    x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.uint8)
-    y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.uint8)
-    self._compareBoth(x, y, np.add, math_ops.add)
-
-  def testInt8Basic(self):
-    x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.int8)
-    y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int8)
-    self._compareBoth(x, y, np.multiply, math_ops.multiply)
-    self._compareBoth(x, y, np.multiply, _MUL)
-
-  def testInt16Basic(self):
-    x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.int16)
-    y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int16)
-    self._compareBoth(x, y, np.multiply, math_ops.multiply)
-    self._compareBoth(x, y, np.multiply, _MUL)
-
-  def testUint16Basic(self):
-    x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.uint16)
-    y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.uint16)
-    self._compareBoth(x, y, np.multiply, math_ops.multiply)
-    self._compareBoth(x, y, np.multiply, _MUL)
-    self._compareBoth(x, y, np.true_divide, math_ops.truediv)
-    self._compareBoth(x, y, np.floor_divide, math_ops.floordiv)
-    self._compareBoth(x, y, np.true_divide, _TRUEDIV)
-    self._compareBoth(x, y, np.floor_divide, _FLOORDIV)
-
-  def testInt32Basic(self):
-    x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.int32)
-    y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int32)
-    self._compareBoth(x, y, np.add, math_ops.add)
-    self._compareBoth(x, y, np.subtract, math_ops.subtract)
-    self._compareBoth(x, y, np.multiply, math_ops.multiply)
-    self._compareBoth(x, y, np.true_divide, math_ops.truediv)
-    self._compareBoth(x, y, np.floor_divide, math_ops.floordiv)
-    self._compareBoth(x, y, np.mod, math_ops.mod)
-    self._compareBoth(x, y, np.add, _ADD)
-    self._compareBoth(x, y, np.subtract, _SUB)
-    self._compareBoth(x, y, np.multiply, _MUL)
-    self._compareBoth(x, y, np.true_divide, _TRUEDIV)
-    self._compareBoth(x, y, np.floor_divide, _FLOORDIV)
-    self._compareBoth(x, y, np.mod, _MOD)
-    # _compareBoth tests on GPU only for floating point types, so test
-    # _MOD for int32 on GPU by calling _compareGpu
-    self._compareGpu(x, y, np.mod, _MOD)
-
-  def testInt64Basic(self):
-    x = np.arange(1 << 40, 13 << 40, 2 << 40).reshape(1, 3, 2).astype(np.int64)
-    y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int64)
-    self._compareBoth(x, y, np.subtract, math_ops.subtract)
-    self._compareBoth(x, y, np.multiply, math_ops.multiply)
-    self._compareBoth(x, y, np.true_divide, math_ops.truediv)
-    self._compareBoth(x, y, np.floor_divide, math_ops.floordiv)
-    self._compareBoth(x, y, np.mod, math_ops.mod)
-    self._compareBoth(x, y, np.subtract, _SUB)
-    self._compareBoth(x, y, np.multiply, _MUL)
-    self._compareBoth(x, y, np.true_divide, _TRUEDIV)
-    self._compareBoth(x, y, np.floor_divide, _FLOORDIV)
-    self._compareBoth(x, y, np.mod, _MOD)
-
-  def testComplex64Basic(self):
-    x = np.complex(1, 1) * np.linspace(-10, 10, 6).reshape(1, 3, 2).astype(
-        np.complex64)
-    y = np.complex(1, 1) * np.linspace(20, -20, 6).reshape(1, 3, 2).astype(
-        np.complex64)
-    self._compareBoth(x, y, np.add, math_ops.add)
-    self._compareBoth(x, y, np.subtract, math_ops.subtract)
-    self._compareBoth(x, y, np.multiply, math_ops.multiply)
-    self._compareBoth(x, y + 0.1, np.true_divide, math_ops.truediv)
-    self._compareBoth(x, y, np.add, _ADD)
-    self._compareBoth(x, y, np.subtract, _SUB)
-    self._compareBoth(x, y, np.multiply, _MUL)
-    self._compareBoth(x, y + 0.1, np.true_divide, _TRUEDIV)
-
-  def testComplex128Basic(self):
-    x = np.complex(1, 1) * np.linspace(-10, 10, 6).reshape(1, 3, 2).astype(
-        np.complex128)
-    y = np.complex(1, 1) * np.linspace(20, -20, 6).reshape(1, 3, 2).astype(
-        np.complex128)
-    self._compareBoth(x, y, np.add, math_ops.add)
-    self._compareBoth(x, y, np.subtract, math_ops.subtract)
-    self._compareBoth(x, y, np.multiply, math_ops.multiply)
-    self._compareBoth(x, y + 0.1, np.true_divide, math_ops.truediv)
-    self._compareBoth(x, y, np.add, _ADD)
-    self._compareBoth(x, y, np.subtract, _SUB)
-    self._compareBoth(x, y, np.multiply, _MUL)
-    self._compareBoth(x, y + 0.1, np.true_divide, _TRUEDIV)
-
-  def testStringComparison(self):
-    x = np.array([["abc", "bh"], ["c", ""]])
-    y = np.array([["abc", "bh"], ["def", "hi"]])
-    with self.test_session(use_gpu=False) as sess:
-      cmp_eq = math_ops.equal(x, y)
-      cmp_not_eq = math_ops.not_equal(x, y)
-      values = sess.run([cmp_eq, cmp_not_eq])
-      self.assertAllEqual([[True, True], [False, False]], values[0])
-      self.assertAllEqual([[False, False], [True, True]], values[1])
-
-  def testString(self):
-    x = np.array(
-        [["x_0_0", "x_0_1", "x_0_2"], ["x_1_0", "x_1_1", "x_1_2"],
-         ["x_2_0", "x_2_1", "x_2_2"]],
-        dtype=np.object)
-    y = np.array(
-        [["y_0_0", "y_0_1", "y_0_2"], ["y_1_0", "y_1_1", "y_1_2"],
-         ["y_2_0", "y_2_1", "y_2_2"]],
-        dtype=np.object)
-    z = np.array([["z_0", "z_1", "z_2"]], dtype=np.object)
-    w = np.array("w", dtype=np.object)
-    self._compareCpu(x, y, _ADD, _ADD)
-    self._compareCpu(x, z, _ADD, _ADD)
-    self._compareCpu(x, w, _ADD, _ADD)
-    self._compareCpu(z, w, _ADD, _ADD)
-
-  def _compareBCast(self, xs, ys, dtype, np_func, tf_func):
-    if dtype in (np.complex64, np.complex128):
-      x = (1 + np.linspace(0, 2 + 3j, np.prod(xs))).astype(dtype).reshape(xs)
-      y = (1 + np.linspace(0, 2 - 2j, np.prod(ys))).astype(dtype).reshape(ys)
-    else:
-      x = (1 + np.linspace(0, 5, np.prod(xs))).astype(dtype).reshape(xs)
-      y = (1 + np.linspace(0, 5, np.prod(ys))).astype(dtype).reshape(ys)
-    self._compareCpu(x, y, np_func, tf_func)
-    if x.dtype in (np.float16, np.float32, np.float64):
-      # TODO(aselle): Make the test work for dtypes:
-      #     (np.complex64, np.complex128).
-      if tf_func not in (_FLOORDIV, math_ops.floordiv):
-        if x.dtype == np.float16:
-          # Compare fp16 theoretical gradients to fp32 numerical gradients,
-          # since fp16 numerical gradients are too imprecise unless great
-          # care is taken with choosing the inputs and the delta. This is
-          # a weaker check (in particular, it does not test the op itself,
-          # only its gradient), but it's much better than nothing.
-          self._compareGradientX(x, y, np_func, tf_func, np.float)
-          self._compareGradientY(x, y, np_func, tf_func, np.float)
-        else:
-          self._compareGradientX(x, y, np_func, tf_func)
-          self._compareGradientY(x, y, np_func, tf_func)
-      self._compareGpu(x, y, np_func, tf_func)
-
-  # TODO(josh11b,vrv): Refactor this to use parameterized tests.
-  def _testBCastByFunc(self, funcs, xs, ys):
-    dtypes = [
-        np.float16,
-        np.float32,
-        np.float64,
-        np.int32,
-        np.int64,
-        np.complex64,
-        np.complex128,
-    ]
-    for dtype in dtypes:
-      for (np_func, tf_func) in funcs:
-        if (dtype in (np.complex64, np.complex128) and
-            tf_func in (_FLOORDIV, math_ops.floordiv)):
-          continue  # floordiv makes no sense for complex numbers
-        self._compareBCast(xs, ys, dtype, np_func, tf_func)
-        self._compareBCast(ys, xs, dtype, np_func, tf_func)
-
-  def _testBCastA(self, xs, ys):
-    funcs = [
-        (np.add, math_ops.add),
-        (np.add, _ADD),
-    ]
-    self._testBCastByFunc(funcs, xs, ys)
-
-  def _testBCastB(self, xs, ys):
-    funcs = [
-        (np.subtract, math_ops.subtract),
-        (np.subtract, _SUB),
-        (np.power, math_ops.pow),
-    ]
-    self._testBCastByFunc(funcs, xs, ys)
-
-  def _testBCastC(self, xs, ys):
-    funcs = [
-        (np.multiply, math_ops.multiply),
-        (np.multiply, _MUL),
-    ]
-    self._testBCastByFunc(funcs, xs, ys)
-
-  def _testBCastD(self, xs, ys):
-    funcs = [
-        (np.true_divide, math_ops.truediv),
-        (np.floor_divide, math_ops.floordiv),
-        (np.true_divide, _TRUEDIV),
-        (np.floor_divide, _FLOORDIV),
-    ]
-    self._testBCastByFunc(funcs, xs, ys)
-
-  def testBCast_0A(self):
-    self._testBCastA([1, 3, 2], [1])
-
-  def testBCast_0B(self):
-    self._testBCastB([1, 3, 2], [1])
-
-  def testBCast_0C(self):
-    self._testBCastC([1, 3, 2], [1])
-
-  def testBCast_0D(self):
-    self._testBCastD([1, 3, 2], [1])
-
-  def testBCast_1A(self):
-    self._testBCastA([1, 3, 2], [2])
-
-  def testBCast_1B(self):
-    self._testBCastB([1, 3, 2], [2])
-
-  def testBCast_1C(self):
-    self._testBCastC([1, 3, 2], [2])
-
-  def testBCast_1D(self):
-    self._testBCastD([1, 3, 2], [2])
-
-  def testBCast_2A(self):
-    self._testBCastA([1, 3, 2], [3, 2])
-
-  def testBCast_2B(self):
-    self._testBCastB([1, 3, 2], [3, 2])
-
-  def testBCast_2C(self):
-    self._testBCastC([1, 3, 2], [3, 2])
-
-  def testBCast_2D(self):
-    self._testBCastD([1, 3, 2], [3, 2])
-
-  def testBCast_3A(self):
-    self._testBCastA([1, 3, 2], [3, 1])
-
-  def testBCast_3B(self):
-    self._testBCastB([1, 3, 2], [3, 1])
-
-  def testBCast_3C(self):
-    self._testBCastC([1, 3, 2], [3, 1])
-
-  def testBCast_3D(self):
-    self._testBCastD([1, 3, 2], [3, 1])
-
-  def testBCast_4A(self):
-    self._testBCastA([1, 3, 2], [1, 3, 2])
-
-  def testBCast_4B(self):
-    self._testBCastB([1, 3, 2], [1, 3, 2])
-
-  def testBCast_4C(self):
-    self._testBCastC([1, 3, 2], [1, 3, 2])
-
-  def testBCast_4D(self):
-    self._testBCastD([1, 3, 2], [1, 3, 2])
-
-  def testBCast_5A(self):
-    self._testBCastA([1, 3, 2], [2, 3, 1])
-
-  def testBCast_5B(self):
-    self._testBCastB([1, 3, 2], [2, 3, 1])
-
-  def testBCast_5C(self):
-    self._testBCastC([1, 3, 2], [2, 3, 1])
-
-  def testBCast_5D(self):
-    self._testBCastD([1, 3, 2], [2, 3, 1])
-
-  def testBCast_6A(self):
-    self._testBCastA([1, 3, 2], [2, 1, 1])
-
-  def testBCast_6B(self):
-    self._testBCastB([1, 3, 2], [2, 1, 1])
-
-  def testBCast_6C(self):
-    self._testBCastC([1, 3, 2], [2, 1, 1])
-
-  def testBCast_6D(self):
-    self._testBCastD([1, 3, 2], [2, 1, 1])
-
-  def testBCast_7A(self):
-    self._testBCastA([1, 3, 2], [1, 3, 1])
-
-  def testBCast_7B(self):
-    self._testBCastB([1, 3, 2], [1, 3, 1])
-
-  def testBCast_7C(self):
-    self._testBCastC([1, 3, 2], [1, 3, 1])
-
-  def testBCast_7D(self):
-    self._testBCastD([1, 3, 2], [1, 3, 1])
-
-  def testBCast_8A(self):
-    self._testBCastA([2, 1, 5], [2, 3, 1])
-
-  def testBCast_8B(self):
-    self._testBCastB([2, 1, 5], [2, 3, 1])
-
-  def testBCast_8C(self):
-    self._testBCastC([2, 1, 5], [2, 3, 1])
-
-  def testBCast_8D(self):
-    self._testBCastD([2, 1, 5], [2, 3, 1])
-
-  def testBCast_9A(self):
-    self._testBCastA([2, 0, 5], [2, 0, 1])
-
-  def testBCast_9B(self):
-    self._testBCastB([2, 0, 5], [2, 0, 1])
-
-  def testBCast_9C(self):
-    self._testBCastC([2, 0, 5], [2, 0, 1])
-
-  def testBCast_9D(self):
-    self._testBCastD([2, 0, 5], [2, 0, 1])
-
-  def testBCast_10A(self):
-    self._testBCastA([2, 3, 0], [2, 3, 1])
-
-  def testBCast_10B(self):
-    self._testBCastB([2, 3, 0], [2, 3, 1])
-
-  def testBCast_10C(self):
-    self._testBCastC([2, 3, 0], [2, 3, 1])
-
-  def testBCast_10D(self):
-    self._testBCastD([2, 3, 0], [2, 3, 1])
-
-  def testBCast_11A(self):
-    self._testBCastA([1, 3, 2], [1, 3, 2])
-
-  def testBCast_11B(self):
-    self._testBCastB([1, 3, 2], [1, 3, 2])
-
-  def testBCast_11C(self):
-    self._testBCastC([1, 3, 2], [1, 3, 2])
-
-  def testBCast_11D(self):
-    self._testBCastD([1, 3, 2], [1, 3, 2])
-
-  def testBCast_12A(self):
-    self._testBCastA([1, 1, 1, 1, 3, 2], [1, 3, 2])
-
-  def testBCast_12B(self):
-    self._testBCastB([1, 1, 1, 1, 3, 2], [1, 3, 2])
-
-  def testBCast_12C(self):
-    self._testBCastC([1, 1, 1, 1, 3, 2], [1, 3, 2])
-
-  def testBCast_12D(self):
-    self._testBCastD([1, 1, 1, 1, 3, 2], [1, 3, 2])
-
-  def testBCast_13A(self):
-    self._testBCastA([1, 3, 2, 1, 1], [1])
-
-  def testBCast_13B(self):
-    self._testBCastB([1, 3, 2, 1, 1], [1])
-
-  def testBCast_13C(self):
-    self._testBCastC([1, 3, 2, 1, 1], [1])
-
-  def testBCast_13D(self):
-    self._testBCastD([1, 3, 2, 1, 1], [1])
-
-  def testBCast_14A(self):
-    self._testBCastA([2, 3, 1, 1, 5], [1])
-
-  def testBCast_14B(self):
-    self._testBCastB([2, 3, 1, 1, 5], [1])
-
-  def testBCast_14C(self):
-    self._testBCastC([2, 3, 1, 1, 5], [1])
-
-  def testBCast_14D(self):
-    self._testBCastD([2, 3, 1, 1, 5], [1])
-
-  def testBCast_15A(self):
-    self._testBCastA([10, 3, 1, 2], [3, 1, 2])
-
-  def testBCast_15B(self):
-    self._testBCastB([10, 3, 1, 2], [3, 1, 2])
-
-  def testBCast_15C(self):
-    self._testBCastC([10, 3, 1, 2], [3, 1, 2])
-
-  def testBCast_15D(self):
-    self._testBCastD([10, 3, 1, 2], [3, 1, 2])
-
-  def testMismatchedDimensions(self):
-    for func in [
-        math_ops.add, math_ops.subtract, math_ops.multiply, math_ops.div, _ADD,
-        _SUB, _MUL, _TRUEDIV, _FLOORDIV
-    ]:
-      with self.assertRaisesWithPredicateMatch(
-          ValueError, lambda e: "Dimensions must" in str(e)):
-        func(
-            ops.convert_to_tensor([10.0, 20.0, 30.0]),
-            ops.convert_to_tensor([[40.0, 50.0], [60.0, 70.0]]))
-
-  def testZeroPowGrad(self):
-    with self.cached_session():
-      for dtype in (np.float16, np.float32, np.float64, np.complex64,
-                    np.complex128):
-        x = constant_op.constant(0.0, dtype=dtype)
-        y = constant_op.constant(2.0, dtype=dtype)
-        z = math_ops.pow(x, y)
-        error = gradient_checker.compute_gradient_error(y, [], z, [])
-        self.assertEqual(error, 0)
-
-  def testComplexPowGrad(self):
-    with self.cached_session():
-      for dtype in np.complex64, np.complex128:
-        for base in 2.0, -2.0:
-          x = constant_op.constant(base, dtype=dtype)
-          y = constant_op.constant(2.0, dtype=dtype)
-          z = math_ops.pow(x, y)
-          error = gradient_checker.compute_gradient_error(y, [], z, [])
-          self.assertLess(error, 2e-4)
-
-  def testAtan2SpecialValues(self):
-    x1l, x2l = zip((+0.0, +0.0), (+0.0, -0.0), (-0.0, +0.0), (-0.0, -0.0),
-                   (1.2345, float("inf")), (1.2345, -float("inf")),
-                   (-4.321, float("inf")), (-4.125, -float("inf")),
-                   (float("inf"), float("inf")), (float("inf"), -float("inf")),
-                   (-float("inf"), float("inf")),
-                   (-float("inf"), -float("inf")))
-    for dtype in np.float32, np.float64:
-      x1 = np.array(x1l).astype(dtype)
-      x2 = np.array(x2l).astype(dtype)
-      self._compareCpu(x1, x2, np.arctan2, math_ops.atan2)
-      self._compareGpu(x1, x2, np.arctan2, math_ops.atan2)
-
-  def testPowNegativeExponent(self):
-    for dtype in [np.int32, np.int64]:
-      with self.test_session(use_gpu=False) as sess:
-        with self.assertRaisesRegexp(
-            errors_impl.InvalidArgumentError,
-            "Integers to negative integer powers are not allowed"):
-          x = np.array([5, 2]).astype(dtype)
-          y = np.array([-2, 3]).astype(dtype)
-          sess.run(math_ops.pow(x, y))
-
-      with self.test_session(use_gpu=False) as sess:
-        with self.assertRaisesRegexp(
-            errors_impl.InvalidArgumentError,
-            "Integers to negative integer powers are not allowed"):
-          x = np.array([5, 2]).astype(dtype)
-          y = np.array([2, -3]).astype(dtype)
-          sess.run(math_ops.pow(x, y))
-
-      with self.test_session(use_gpu=False) as sess:
-        with self.assertRaisesRegexp(
-            errors_impl.InvalidArgumentError,
-            "Integers to negative integer powers are not allowed"):
-          x = np.array([5, 2]).astype(dtype)
-          y = -3
-          sess.run(math_ops.pow(x, y))
-
-
 class ComparisonOpTest(test.TestCase):
 
   def _compareScalar(self, func, x, y, dtype):
diff --git a/tensorflow/python/kernel_tests/cwise_ops_unary_test.py b/tensorflow/python/kernel_tests/cwise_ops_unary_test.py
new file mode 100644
index 0000000..77f1827
--- /dev/null
+++ b/tensorflow/python/kernel_tests/cwise_ops_unary_test.py
@@ -0,0 +1,541 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for unary coefficient-wise operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes as dtypes_lib
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+
+_NEG = lambda x: -x
+_ABS = abs
+
+
+# TODO(zongheng): it'd be great to factor out this function and various random
+# SparseTensor gen funcs.
+def _sparsify(x, thresh=0.5, index_dtype=np.int64):
+  x[x < thresh] = 0
+
+  non_zero = np.where(x)
+  x_indices = np.vstack(non_zero).astype(index_dtype).T
+  x_values = x[non_zero]
+  x_shape = x.shape
+
+  return sparse_tensor.SparseTensor(
+      indices=x_indices, values=x_values, dense_shape=x_shape), x_values
+
+
+def _default_tolerance(dtype):
+  """Returns a sensible default tolerance for comparing results of a given type.
+
+  Args:
+    dtype: A datatype.
+  """
+  if dtype == np.float16:
+    return 5e-3
+  elif dtype in (np.float32, np.complex64):
+    return 1e-3
+  elif dtype in (np.float64, np.complex128):
+    return 1e-5
+  else:
+    return None  # Fail fast for unexpected types
+
+
+class UnaryOpTest(test.TestCase):
+
+  def _compareCpu(self, x, np_func, tf_func, grad_rtol=None, grad_atol=None):
+    if grad_rtol is None:
+      grad_rtol = _default_tolerance(x.dtype)
+    if grad_atol is None:
+      grad_atol = _default_tolerance(x.dtype)
+    np_ans = np_func(x)
+    with self.test_session(use_gpu=False):
+      inx = ops.convert_to_tensor(x)
+      if x.dtype in (np.float32, np.float64,
+                     dtypes_lib.bfloat16.as_numpy_dtype):
+        y = 1.1 * tf_func(inx)
+        np_ans *= 1.1
+      else:
+        y = tf_func(inx)
+      tf_cpu = y.eval()
+      self.assertShapeEqual(np_ans, y)
+      if x.dtype == np.float16:
+        self.assertAllClose(np_ans, tf_cpu, rtol=1e-3, atol=1e-3)
+      elif x.dtype == dtypes_lib.bfloat16.as_numpy_dtype:
+        self.assertAllClose(np_ans, tf_cpu, rtol=1e-2, atol=1e-2)
+      else:
+        self.assertAllClose(np_ans, tf_cpu)
+
+      if x.dtype in (np.complex64, np.complex128) and tf_func == math_ops.sign:
+        return  # Return early
+
+      if x.dtype == np.float16:
+        s = list(np.shape(x))
+        jacob_t, _ = gradient_checker.compute_gradient(
+            inx, s, y, s, x_init_value=x)
+        xf = x.astype(np.float)
+        inxf = ops.convert_to_tensor(xf)
+        yf = tf_func(inxf)
+        _, jacob_n = gradient_checker.compute_gradient(
+            inxf, s, yf, s, x_init_value=xf, delta=1e-2)
+        jacob_n = jacob_n.astype(np.float16)
+        self.assertAllClose(jacob_t, jacob_n, rtol=grad_rtol, atol=grad_atol)
+      elif x.dtype in (np.float32, np.complex64):
+        s = list(np.shape(x))
+        jacob_t, jacob_n = gradient_checker.compute_gradient(
+            inx, s, y, s, x_init_value=x, delta=1e-3)
+        self.assertAllClose(jacob_t, jacob_n, rtol=grad_rtol, atol=grad_atol)
+      elif x.dtype in (np.float64, np.complex128):
+        s = list(np.shape(x))
+        jacob_t, jacob_n = gradient_checker.compute_gradient(
+            inx, s, y, s, x_init_value=x, delta=1e-5)
+        self.assertAllClose(jacob_t, jacob_n, rtol=grad_rtol, atol=grad_atol)
+
+  def _check(self, result_tensor, result_np, input_sp_t, tol):
+    self.assertTrue(isinstance(result_tensor, sparse_tensor.SparseTensor))
+    self.assertTrue(isinstance(input_sp_t, sparse_tensor.SparseTensor))
+    self.assertAllEqual(input_sp_t.indices.eval(), result_tensor.indices.eval())
+    self.assertAllEqual(input_sp_t.dense_shape.eval(),
+                        result_tensor.dense_shape.eval())
+    if tol is None:
+      self.assertAllClose(result_np, result_tensor.values.eval())
+    else:
+      self.assertAllClose(
+          result_np, result_tensor.values.eval(), rtol=tol, atol=tol)
+
+  def _compareSparseCpu(self, x, np_func, tf_func, tol):
+    x_sp, x_sp_vals = _sparsify(x)
+    res_np = np_func(x_sp_vals)
+    with self.test_session(use_gpu=False):
+      self._check(tf_func(x_sp), res_np, x_sp, tol)
+
+  def _compareGpu(self, x, np_func, tf_func):
+    np_ans = np_func(x)
+    with self.test_session(force_gpu=test_util.is_gpu_available()):
+      result = tf_func(ops.convert_to_tensor(x))
+      tf_gpu = result.eval()
+    if x.dtype == np.float16:
+      self.assertAllClose(np_ans, tf_gpu, rtol=1e-3, atol=1e-3)
+    else:
+      self.assertAllClose(np_ans, tf_gpu)
+    # TODO(zhifengc/ke): make gradient checker work on GPU.
+
+  def _compareSparseGpu(self, x, np_func, tf_func, tol):
+    x_sp, x_sp_vals = _sparsify(x)
+    res_np = np_func(x_sp_vals)
+    with self.test_session(force_gpu=test_util.is_gpu_available()):
+      self._check(tf_func(x_sp), res_np, x_sp, tol)
+
+  def _compareBoth(self, x, np_func, tf_func):
+    self._compareCpu(x, np_func, tf_func)
+    self._compareGpu(x, np_func, tf_func)
+
+  def _compareBothSparse(self, x, np_func, tf_func, tol=None):
+    self._compareSparseCpu(x, np_func, tf_func, tol)
+    self._compareSparseGpu(x, np_func, tf_func, tol)
+
+  def _inv(self, x):
+    return 1.0 / x
+
+  def _rsqrt(self, x):
+    return self._inv(np.sqrt(x))
+
+  def _sigmoid(self, x):
+    return 1.0 / (1.0 + np.exp(-x))
+
+  def _log_sigmoid(self, x):
+    return np.log(self._sigmoid(x))
+
+  def _replace_domain_error_with_inf(self, fn):
+
+    def func(x):
+      try:
+        return fn(x)
+      except ValueError as e:
+        if "domain error" in str(e):
+          return np.inf * np.ones_like(x)
+        else:
+          raise e
+
+    return func
+
+  def testFloatBasic(self):
+    x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float32)
+    w = x - x.min() + 1.02  # all greater than 1
+    y = (x + .5).astype(np.float32)  # no zero
+    z = (x + 15.5).astype(np.float32)  # all positive
+    k = np.arange(-0.90, 0.90, 0.25).astype(np.float32)  # between -1 and 1
+
+    self._compareBoth(x, np.abs, math_ops.abs)
+    self._compareBoth(x, np.abs, _ABS)
+    self._compareBoth(x, np.negative, math_ops.negative)
+    self._compareBoth(x, np.negative, _NEG)
+    self._compareBoth(y, self._inv, math_ops.reciprocal)
+    self._compareBoth(x, np.square, math_ops.square)
+    self._compareBoth(z, np.sqrt, math_ops.sqrt)
+    self._compareBoth(z, self._rsqrt, math_ops.rsqrt)
+    self._compareBoth(x, np.exp, math_ops.exp)
+    self._compareBoth(x, np.expm1, math_ops.expm1)
+    self._compareBoth(z, np.log, math_ops.log)
+    self._compareBoth(z, np.log1p, math_ops.log1p)
+    self._compareBoth(x, np.sinh, math_ops.sinh)
+    self._compareBoth(x, np.cosh, math_ops.cosh)
+    self._compareBoth(x, np.tanh, math_ops.tanh)
+    self._compareBoth(x, np.arcsinh, math_ops.asinh)
+    self._compareBoth(w, np.arccosh, math_ops.acosh)
+    self._compareBoth(k, np.arctanh, math_ops.atanh)
+    self._compareBoth(x, self._sigmoid, math_ops.sigmoid)
+    self._compareBoth(x, self._log_sigmoid, math_ops.log_sigmoid)
+    self._compareBoth(y, np.sign, math_ops.sign)
+    self._compareBoth(x, np.sin, math_ops.sin)
+    self._compareBoth(x, np.cos, math_ops.cos)
+    self._compareBoth(k, np.arcsin, math_ops.asin)
+    self._compareBoth(k, np.arccos, math_ops.acos)
+    self._compareBoth(x, np.arctan, math_ops.atan)
+    self._compareBoth(x, np.tan, math_ops.tan)
+    self._compareBoth(
+        y, np.vectorize(self._replace_domain_error_with_inf(math.lgamma)),
+        math_ops.lgamma)
+    self._compareBoth(x, np.vectorize(math.erf), math_ops.erf)
+    self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc)
+    try:
+      from scipy import special  # pylint: disable=g-import-not-at-top
+      self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
+      self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
+    except ImportError as e:
+      tf_logging.warn("Cannot test special functions: %s" % str(e))
+
+    self._compareBothSparse(x, np.abs, math_ops.abs)
+    self._compareBothSparse(x, np.negative, math_ops.negative)
+    self._compareBothSparse(x, np.square, math_ops.square)
+    self._compareBothSparse(z, np.sqrt, math_ops.sqrt, tol=1e-3)
+    self._compareBothSparse(x, np.tanh, math_ops.tanh)
+    self._compareBothSparse(y, np.sign, math_ops.sign)
+    self._compareBothSparse(x, np.vectorize(math.erf), math_ops.erf)
+
+  def testFloatTanhEdge(self):
+    x = np.arange(40, 40 + 6).reshape(6).astype(np.float32)
+    self._compareBoth(x, np.tanh, math_ops.tanh)
+    x = np.arange(-40, -40 + 6).reshape(6).astype(np.float32)
+    self._compareBoth(x, np.tanh, math_ops.tanh)
+
+  def testFloatEmpty(self):
+    x = np.empty((2, 0, 5), dtype=np.float32)
+    self._compareBoth(x, np.abs, math_ops.abs)
+    self._compareBoth(x, np.abs, _ABS)
+    self._compareBoth(x, np.negative, math_ops.negative)
+    self._compareBoth(x, np.negative, _NEG)
+    self._compareBoth(x, self._inv, math_ops.reciprocal)
+    self._compareBoth(x, np.square, math_ops.square)
+    self._compareBoth(x, np.sqrt, math_ops.sqrt)
+    self._compareBoth(x, self._rsqrt, math_ops.rsqrt)
+    self._compareBoth(x, np.exp, math_ops.exp)
+    self._compareBoth(x, np.expm1, math_ops.expm1)
+    self._compareBoth(x, np.log, math_ops.log)
+    self._compareBoth(x, np.log1p, math_ops.log1p)
+    self._compareBoth(x, np.sinh, math_ops.sinh)
+    self._compareBoth(x, np.arcsinh, math_ops.asinh)
+    self._compareBoth(x, np.cosh, math_ops.cosh)
+    self._compareBoth(x, np.tanh, math_ops.tanh)
+    self._compareBoth(x, self._sigmoid, math_ops.sigmoid)
+    self._compareBoth(x, np.sign, math_ops.sign)
+    self._compareBoth(x, np.sin, math_ops.sin)
+    self._compareBoth(x, np.cos, math_ops.cos)
+    # np.vectorize fails on size-0 inputs, so use an arbitrary elementwise
+    # function (np.sign) as the numpy reference for lgamma/erf/erfc below.
+    self._compareBoth(x, np.sign, math_ops.lgamma)
+    self._compareBoth(x, np.sign, math_ops.erf)
+    self._compareBoth(x, np.sign, math_ops.erfc)
+    self._compareBoth(x, np.tan, math_ops.tan)
+    self._compareBoth(x, np.arcsin, math_ops.asin)
+    self._compareBoth(x, np.arccos, math_ops.acos)
+    self._compareBoth(x, np.arctan, math_ops.atan)
+    try:
+      from scipy import special  # pylint: disable=g-import-not-at-top
+      self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
+      self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
+    except ImportError as e:
+      tf_logging.warn("Cannot test special functions: %s" % str(e))
+
+    self._compareBothSparse(x, np.abs, math_ops.abs)
+    self._compareBothSparse(x, np.negative, math_ops.negative)
+    self._compareBothSparse(x, np.square, math_ops.square)
+    self._compareBothSparse(x, np.sqrt, math_ops.sqrt, tol=1e-3)
+    self._compareBothSparse(x, np.tanh, math_ops.tanh)
+    self._compareBothSparse(x, np.sign, math_ops.sign)
+    self._compareBothSparse(x, np.sign, math_ops.erf)
+
+  def testDoubleBasic(self):
+    x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float64)
+    w = x - x.min() + 1.02  # all greater than 1
+    y = (x + .5).astype(np.float64)  # no zero
+    z = (x + 15.5).astype(np.float64)  # all positive
+    k = np.arange(-0.90, 0.90,
+                  0.35).reshape(1, 3, 2).astype(np.float64)  # between -1 and 1
+    self._compareBoth(x, np.abs, math_ops.abs)
+    self._compareBoth(x, np.abs, _ABS)
+    self._compareBoth(x, np.negative, math_ops.negative)
+    self._compareBoth(x, np.negative, _NEG)
+    self._compareBoth(y, self._inv, math_ops.reciprocal)
+    self._compareBoth(x, np.square, math_ops.square)
+    self._compareBoth(z, np.sqrt, math_ops.sqrt)
+    self._compareBoth(z, self._rsqrt, math_ops.rsqrt)
+    self._compareBoth(x, np.exp, math_ops.exp)
+    self._compareBoth(x, np.expm1, math_ops.expm1)
+    self._compareBoth(z, np.log, math_ops.log)
+    self._compareBoth(z, np.log1p, math_ops.log1p)
+    self._compareBoth(x, np.sinh, math_ops.sinh)
+    self._compareBoth(x, np.cosh, math_ops.cosh)
+    self._compareBoth(x, np.tanh, math_ops.tanh)
+    self._compareBoth(x, np.arcsinh, math_ops.asinh)
+    self._compareBoth(w, np.arccosh, math_ops.acosh)
+    self._compareBoth(k, np.arctanh, math_ops.atanh)
+    self._compareBoth(x, self._sigmoid, math_ops.sigmoid)
+    self._compareBoth(y, np.sign, math_ops.sign)
+    self._compareBoth(x, np.sin, math_ops.sin)
+    self._compareBoth(x, np.cos, math_ops.cos)
+    self._compareBoth(
+        y, np.vectorize(self._replace_domain_error_with_inf(math.lgamma)),
+        math_ops.lgamma)
+    self._compareBoth(x, np.vectorize(math.erf), math_ops.erf)
+    self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc)
+    self._compareBoth(x, np.arctan, math_ops.atan)
+    self._compareBoth(k, np.arcsin, math_ops.asin)
+    self._compareBoth(k, np.arccos, math_ops.acos)
+    self._compareBoth(k, np.tan, math_ops.tan)
+    try:
+      from scipy import special  # pylint: disable=g-import-not-at-top
+      self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
+      self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
+    except ImportError as e:
+      tf_logging.warn("Cannot test special functions: %s" % str(e))
+
+    self._compareBothSparse(x, np.abs, math_ops.abs)
+    self._compareBothSparse(x, np.negative, math_ops.negative)
+    self._compareBothSparse(x, np.square, math_ops.square)
+    self._compareBothSparse(z, np.sqrt, math_ops.sqrt, tol=1e-3)
+    self._compareBothSparse(x, np.tanh, math_ops.tanh)
+    self._compareBothSparse(y, np.sign, math_ops.sign)
+    self._compareBothSparse(x, np.vectorize(math.erf), math_ops.erf)
+
+  def testHalfBasic(self):
+    x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float16)
+    y = (x + .5).astype(np.float16)  # no zero
+    z = (x + 15.5).astype(np.float16)  # all positive
+    self._compareBoth(x, np.abs, math_ops.abs)
+    self._compareBoth(x, np.abs, _ABS)
+    self._compareBoth(x, np.negative, math_ops.negative)
+    self._compareBoth(x, np.negative, _NEG)
+    self._compareBoth(y, self._inv, math_ops.reciprocal)
+    self._compareBoth(x, np.square, math_ops.square)
+    self._compareBoth(z, np.sqrt, math_ops.sqrt)
+    self._compareBoth(z, self._rsqrt, math_ops.rsqrt)
+    self._compareBoth(x, np.exp, math_ops.exp)
+    self._compareBoth(x, np.expm1, math_ops.expm1)
+    self._compareBoth(z, np.log, math_ops.log)
+    self._compareBoth(z, np.log1p, math_ops.log1p)
+    self._compareBoth(x, np.tanh, math_ops.tanh)
+    self._compareBoth(x, self._sigmoid, math_ops.sigmoid)
+    self._compareBoth(y, np.sign, math_ops.sign)
+    self._compareBoth(x, np.sin, math_ops.sin)
+    self._compareBoth(x, np.cos, math_ops.cos)
+    self._compareBoth(
+        y, np.vectorize(self._replace_domain_error_with_inf(math.lgamma)),
+        math_ops.lgamma)
+    self._compareBoth(x, np.vectorize(math.erf), math_ops.erf)
+    self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc)
+    try:
+      from scipy import special  # pylint: disable=g-import-not-at-top
+      self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
+      self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
+    except ImportError as e:
+      tf_logging.warn("Cannot test special functions: %s" % str(e))
+
+    self._compareBothSparse(x, np.abs, math_ops.abs)
+    self._compareBothSparse(x, np.negative, math_ops.negative)
+    self._compareBothSparse(x, np.square, math_ops.square)
+    self._compareBothSparse(z, np.sqrt, math_ops.sqrt, tol=1e-3)
+    self._compareBothSparse(x, np.tanh, math_ops.tanh)
+    self._compareBothSparse(y, np.sign, math_ops.sign)
+    self._compareBothSparse(x, np.vectorize(math.erf), math_ops.erf, tol=1e-3)
+
+  def testInt32Basic(self):
+    x = np.arange(-6, 6, 2).reshape(1, 3, 2).astype(np.int32)
+    self._compareCpu(x, np.abs, math_ops.abs)
+    self._compareCpu(x, np.abs, _ABS)
+    self._compareBoth(x, np.negative, math_ops.negative)
+    self._compareBoth(x, np.negative, _NEG)
+    self._compareBoth(x, np.square, math_ops.square)
+    self._compareCpu(x, np.sign, math_ops.sign)
+
+    self._compareBothSparse(x, np.abs, math_ops.abs)
+    self._compareBothSparse(x, np.negative, math_ops.negative)
+    self._compareBothSparse(x, np.square, math_ops.square)
+    self._compareBothSparse(x, np.sign, math_ops.sign)
+
+  def testInt64Basic(self):
+    x = np.arange(-6 << 40, 6 << 40, 2 << 40).reshape(1, 3, 2).astype(np.int64)
+    self._compareCpu(x, np.abs, math_ops.abs)
+    self._compareCpu(x, np.abs, _ABS)
+    self._compareCpu(x, np.negative, math_ops.negative)
+    self._compareCpu(x, np.negative, _NEG)
+    self._compareCpu(x, np.sign, math_ops.sign)
+
+    self._compareBothSparse(x, np.abs, math_ops.abs)
+    self._compareBothSparse(x, np.negative, math_ops.negative)
+    self._compareBothSparse(x, np.sign, math_ops.sign)
+
+  def testInt64Square(self):
+    x = np.arange(-6 << 20, 6 << 20, 2 << 20).reshape(1, 3, 2).astype(np.int64)
+    self._compareCpu(x, np.square, math_ops.square)
+    self._compareBothSparse(x, np.square, math_ops.square)
+
+  def testComplex64Basic(self):
+    x = np.complex(1, 1) * np.arange(-3, 3).reshape(1, 3, 2).astype(
+        np.complex64)
+    y = x + np.complex(0.5, 0.5)  # no zeros
+    self._compareBoth(x, np.abs, math_ops.abs)
+    self._compareBoth(x, np.abs, _ABS)
+    self._compareBoth(x, np.negative, math_ops.negative)
+    self._compareBoth(x, np.negative, _NEG)
+    self._compareCpu(y, self._inv, math_ops.reciprocal)
+    self._compareCpu(x, np.square, math_ops.square)
+    self._compareCpu(y, np.sqrt, math_ops.sqrt)
+    self._compareCpu(y, self._rsqrt, math_ops.rsqrt)
+    self._compareBoth(x, np.exp, math_ops.exp)
+    self._compareCpu(x, np.expm1, math_ops.expm1)
+    self._compareCpu(y, np.log, math_ops.log)
+    self._compareCpu(y, np.log1p, math_ops.log1p)
+    self._compareCpu(x, np.sinh, math_ops.sinh)
+    self._compareCpu(x, np.cosh, math_ops.cosh)
+    self._compareCpu(x, np.tanh, math_ops.tanh)
+
+    # Complex64 versions of asinh() and acosh() in libstdc++ only have 6 digits
+    # of precision. Small gradient values combined with that low precision
+    # produce a high relative error, so relax the gradient tolerance here.
+    self._compareCpu(y, np.arcsinh, math_ops.asinh, grad_rtol=1e-2)
+    self._compareCpu(y, np.arccosh, math_ops.acosh, grad_rtol=1e-2)
+
+    self._compareCpu(y, np.arctanh, math_ops.atanh)
+    self._compareCpu(x, self._sigmoid, math_ops.sigmoid)
+    self._compareCpu(x, np.sin, math_ops.sin)
+    self._compareCpu(x, np.cos, math_ops.cos)
+
+    self._compareBothSparse(x, np.abs, math_ops.abs)
+    self._compareBothSparse(x, np.negative, math_ops.negative)
+    self._compareBothSparse(x, np.square, math_ops.square)
+    self._compareBothSparse(x, np.sqrt, math_ops.sqrt, 1e-3)
+    self._compareBothSparse(x, np.tanh, math_ops.tanh)
+
+    # Numpy uses an incorrect definition of sign; use the right one instead.
+    def complex_sign(x):
+      return x / np.abs(x)
+
+    self._compareBoth(y, complex_sign, math_ops.sign)
+    self._compareBothSparse(y, complex_sign, math_ops.sign)
+
+  def testComplex128Basic(self):
+    x = np.complex(1, 1) * np.arange(-3, 3).reshape(1, 3, 2).astype(
+        np.complex128)
+    y = x + np.complex(0.5, 0.5)  # no zeros
+    self._compareBoth(x, np.abs, math_ops.abs)
+    self._compareBoth(x, np.abs, _ABS)
+    self._compareBoth(x, np.negative, math_ops.negative)
+    self._compareBoth(x, np.negative, _NEG)
+    self._compareCpu(y, self._inv, math_ops.reciprocal)
+    self._compareCpu(x, np.square, math_ops.square)
+    self._compareCpu(y, np.sqrt, math_ops.sqrt)
+    self._compareCpu(y, self._rsqrt, math_ops.rsqrt)
+    self._compareBoth(x, np.exp, math_ops.exp)
+    self._compareCpu(x, np.expm1, math_ops.expm1)
+    self._compareCpu(y, np.log, math_ops.log)
+    self._compareCpu(y, np.log1p, math_ops.log1p)
+    self._compareCpu(x, np.sinh, math_ops.sinh)
+    self._compareCpu(x, np.cosh, math_ops.cosh)
+    self._compareCpu(x, np.tanh, math_ops.tanh)
+    self._compareCpu(y, np.arcsinh, math_ops.asinh)
+    self._compareCpu(y, np.arccosh, math_ops.acosh)
+    self._compareCpu(y, np.arctanh, math_ops.atanh)
+    self._compareCpu(x, self._sigmoid, math_ops.sigmoid)
+    self._compareCpu(x, np.sin, math_ops.sin)
+    self._compareCpu(x, np.cos, math_ops.cos)
+
+    self._compareBothSparse(x, np.abs, math_ops.abs)
+    self._compareBothSparse(x, np.negative, math_ops.negative)
+    self._compareBothSparse(x, np.square, math_ops.square)
+    self._compareBothSparse(x, np.sqrt, math_ops.sqrt, 1e-3)
+    self._compareBothSparse(x, np.tanh, math_ops.tanh)
+
+    # Numpy uses an incorrect definition of sign; use the right one instead.
+    def complex_sign(x):
+      return x / np.abs(x)
+
+    self._compareBoth(y, complex_sign, math_ops.sign)
+    self._compareBothSparse(y, complex_sign, math_ops.sign)
+
+  def testGradGrad(self):
+    np.random.seed(7)
+    shape = (5,)
+    dtype_tols = [(np.float32, 5e-4), (np.float64, 1e-6), (np.complex64, 5e-4),
+                  (np.complex128, 1e-6)]
+    op_range = [
+        (gen_math_ops.reciprocal_grad, [-2, 2]),
+        (gen_math_ops.rsqrt_grad, [0.1, 3]),
+        (gen_math_ops.sigmoid_grad, [-2, 2]),
+        (gen_math_ops.sqrt_grad, [0.1, 3]),
+        (gen_math_ops.tanh_grad, [-2, 2]),
+    ]
+
+    def rand(dtype, real_range):
+      x = np.random.uniform(
+          real_range[0], real_range[1], size=shape[0]).astype(dtype)
+      if dtype in (np.complex64, np.complex128):
+        x += 1j * np.random.uniform(-2, 2, size=shape[0]).astype(dtype)
+      return x
+
+    for op, real_range in op_range:
+      with self.cached_session():
+        for dtype, tol in dtype_tols:
+          x = constant_op.constant(rand(dtype, real_range))
+          y = constant_op.constant(rand(dtype, real_range))
+          z = op(x, y)
+          grads = gradient_checker.compute_gradient(
+              [x, y], [shape, shape],
+              z,
+              shape,
+              x_init_value=[rand(dtype, real_range),
+                            rand(dtype, real_range)])
+          if isinstance(grads, tuple):
+            grads = [grads]
+          for analytical, numerical in grads:
+            self.assertAllClose(analytical, numerical, rtol=tol, atol=tol)
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/python/kernel_tests/linalg_grad_test.py b/tensorflow/python/kernel_tests/linalg_grad_test.py
index cd6a34d..e52f303 100644
--- a/tensorflow/python/kernel_tests/linalg_grad_test.py
+++ b/tensorflow/python/kernel_tests/linalg_grad_test.py
@@ -120,7 +120,7 @@
       delta = epsilon**(1.0 / 3.0)
       # tolerance obtained by looking at actual differences using
       # np.linalg.norm(theoretical-numerical, np.inf) on -mavx build
-      tol = 1e-6 if dtype_ == np.float64 else float32_tol_fudge * 0.04
+      tol = 1e-6 if dtype_ == np.float64 else float32_tol_fudge * 0.05
       # The gradients for a and b may be of very different magnitudes,
       # so to not get spurious failures we test them separately.
       for factor, factor_init in [a, a_np], [b, b_np]:
diff --git a/tensorflow/python/kernel_tests/logging_ops_logging_level_test.py b/tensorflow/python/kernel_tests/logging_ops_logging_level_test.py
new file mode 100644
index 0000000..252090b
--- /dev/null
+++ b/tensorflow/python/kernel_tests/logging_ops_logging_level_test.py
@@ -0,0 +1,70 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.kernels.logging_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import logging_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+
+
+class PrintV2LoggingLevelTest(test.TestCase):
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testPrintOneTensorLogInfo(self):
+    with self.test_session():
+      tensor = math_ops.range(10)
+      with self.captureWritesToStream(sys.stderr) as printed:
+        print_op = logging_ops.print_v2(
+            tensor, output_stream=tf_logging.info)
+        self.evaluate(print_op)
+      self.assertTrue("I" in printed.contents())
+      expected = "[0 1 2 ... 7 8 9]"
+      self.assertTrue(expected in printed.contents())
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testPrintOneTensorLogWarning(self):
+    with self.test_session():
+      tensor = math_ops.range(10)
+      with self.captureWritesToStream(sys.stderr) as printed:
+        print_op = logging_ops.print_v2(
+            tensor, output_stream=tf_logging.warning)
+        self.evaluate(print_op)
+      self.assertTrue("W" in printed.contents())
+      expected = "[0 1 2 ... 7 8 9]"
+      self.assertTrue(expected in printed.contents())
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testPrintOneTensorLogError(self):
+    with self.test_session():
+      tensor = math_ops.range(10)
+      with self.captureWritesToStream(sys.stderr) as printed:
+        print_op = logging_ops.print_v2(
+            tensor, output_stream=tf_logging.error)
+        self.evaluate(print_op)
+      self.assertTrue("E" in printed.contents())
+      expected = "[0 1 2 ... 7 8 9]"
+      self.assertTrue(expected in printed.contents())
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/python/kernel_tests/logging_ops_test.py b/tensorflow/python/kernel_tests/logging_ops_test.py
index 82729b9..cf0beba 100644
--- a/tensorflow/python/kernel_tests/logging_ops_test.py
+++ b/tensorflow/python/kernel_tests/logging_ops_test.py
@@ -18,14 +18,23 @@
 from __future__ import division
 from __future__ import print_function
 
+import sys
+
+from tensorflow.python.eager import context
+from tensorflow.python.eager import function
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import logging_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
 
 
 class LoggingOpsTest(test.TestCase):
@@ -57,6 +66,302 @@
         out.eval()
 
 
+class PrintV2Test(test.TestCase):
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testPrintOneTensor(self):
+    with self.test_session():
+      tensor = math_ops.range(10)
+      with self.captureWritesToStream(sys.stderr) as printed:
+        print_op = logging_ops.print_v2(tensor)
+        self.evaluate(print_op)
+
+      expected = "[0 1 2 ... 7 8 9]"
+      self.assertTrue((expected + "\n") in printed.contents())
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testPrintOneTensorVarySummarize(self):
+    with self.test_session():
+      tensor = math_ops.range(10)
+      with self.captureWritesToStream(sys.stderr) as printed:
+        print_op = logging_ops.print_v2(tensor, summarize=1)
+        self.evaluate(print_op)
+
+      expected = "[0 ... 9]"
+      self.assertTrue((expected + "\n") in printed.contents())
+
+    with self.test_session():
+      tensor = math_ops.range(10)
+      with self.captureWritesToStream(sys.stderr) as printed:
+        print_op = logging_ops.print_v2(tensor, summarize=2)
+        self.evaluate(print_op)
+
+      expected = "[0 1 ... 8 9]"
+      self.assertTrue((expected + "\n") in printed.contents())
+
+    with self.test_session():
+      tensor = math_ops.range(10)
+      with self.captureWritesToStream(sys.stderr) as printed:
+        print_op = logging_ops.print_v2(tensor, summarize=3)
+        self.evaluate(print_op)
+
+      expected = "[0 1 2 ... 7 8 9]"
+      self.assertTrue((expected + "\n") in printed.contents())
+
+    with self.test_session():
+      tensor = math_ops.range(10)
+      with self.captureWritesToStream(sys.stderr) as printed:
+        print_op = logging_ops.print_v2(tensor, summarize=-1)
+        self.evaluate(print_op)
+
+      expected = "[0 1 2 3 4 5 6 7 8 9]"
+      self.assertTrue((expected + "\n") in printed.contents())
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testPrintOneVariable(self):
+    with self.test_session():
+      var = variables.Variable(math_ops.range(10))
+      if not context.executing_eagerly():
+        variables.global_variables_initializer().run()
+      with self.captureWritesToStream(sys.stderr) as printed:
+        print_op = logging_ops.print_v2(var)
+        self.evaluate(print_op)
+      expected = "[0 1 2 ... 7 8 9]"
+      self.assertTrue((expected + "\n") in printed.contents())
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testPrintTwoVariablesInStructWithAssignAdd(self):
+    with self.test_session():
+      var_one = variables.Variable(2.14)
+      plus_one = var_one.assign_add(1.0)
+      var_two = variables.Variable(math_ops.range(10))
+      if not context.executing_eagerly():
+        variables.global_variables_initializer().run()
+      with self.captureWritesToStream(sys.stderr) as printed:
+        self.evaluate(plus_one)
+        print_op = logging_ops.print_v2(var_one, {"second": var_two})
+        self.evaluate(print_op)
+      expected = "3.14 {'second': [0 1 2 ... 7 8 9]}"
+      self.assertTrue((expected + "\n") in printed.contents())
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testPrintTwoTensors(self):
+    with self.test_session():
+      tensor = math_ops.range(10)
+      with self.captureWritesToStream(sys.stderr) as printed:
+        print_op = logging_ops.print_v2(tensor, tensor * 10)
+        self.evaluate(print_op)
+      expected = "[0 1 2 ... 7 8 9] [0 10 20 ... 70 80 90]"
+      self.assertTrue((expected + "\n") in printed.contents())
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testPrintPlaceholderGeneration(self):
+    with self.test_session():
+      tensor = math_ops.range(10)
+      with self.captureWritesToStream(sys.stderr) as printed:
+        print_op = logging_ops.print_v2("{}6", {"{}": tensor * 10})
+        self.evaluate(print_op)
+      expected = "{}6 {'{}': [0 10 20 ... 70 80 90]}"
+      self.assertTrue((expected + "\n") in printed.contents())
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testPrintNoTensors(self):
+    with self.test_session():
+      with self.captureWritesToStream(sys.stderr) as printed:
+        print_op = logging_ops.print_v2(23, [23, 5], {"6": 12})
+        self.evaluate(print_op)
+      expected = "23 [23, 5] {'6': 12}"
+      self.assertTrue((expected + "\n") in printed.contents())
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testPrintFloatScalar(self):
+    with self.test_session():
+      tensor = ops.convert_to_tensor(434.43)
+      with self.captureWritesToStream(sys.stderr) as printed:
+        print_op = logging_ops.print_v2(tensor)
+        self.evaluate(print_op)
+      expected = "434.43"
+      self.assertTrue((expected + "\n") in printed.contents())
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testPrintStringScalar(self):
+    with self.test_session():
+      tensor = ops.convert_to_tensor("scalar")
+      with self.captureWritesToStream(sys.stderr) as printed:
+        print_op = logging_ops.print_v2(tensor)
+        self.evaluate(print_op)
+      expected = "scalar"
+      self.assertTrue((expected + "\n") in printed.contents())
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testPrintComplexTensorStruct(self):
+    with self.test_session():
+      tensor = math_ops.range(10)
+      small_tensor = constant_op.constant([0.3, 12.4, -16.1])
+      big_tensor = math_ops.mul(tensor, 10)
+      with self.captureWritesToStream(sys.stderr) as printed:
+        print_op = logging_ops.print_v2(
+            "first:", tensor, "middle:",
+            {"small": small_tensor, "Big": big_tensor}, 10,
+            [tensor * 2, tensor])
+        self.evaluate(print_op)
+      # Note that the dict keys are printed in sorted order, and uppercase
+      # sorts before lowercase, so 'Big' comes before 'small'.
+      expected = ("first: [0 1 2 ... 7 8 9] "
+                  "middle: {'Big': [0 10 20 ... 70 80 90], "
+                  "'small': [0.3 12.4 -16.1]} "
+                  "10 [[0 2 4 ... 14 16 18], [0 1 2 ... 7 8 9]]")
+      self.assertTrue((expected + "\n") in printed.contents())
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testPrintSparseTensor(self):
+    with self.test_session():
+      ind = [[0, 0], [1, 0], [1, 3], [4, 1], [1, 4], [3, 2], [3, 3]]
+      val = [0, 10, 13, 4, 14, 32, 33]
+      shape = [5, 6]
+
+      sparse = sparse_tensor.SparseTensor(
+          constant_op.constant(ind, dtypes.int64),
+          constant_op.constant(val, dtypes.int64),
+          constant_op.constant(shape, dtypes.int64))
+
+      with self.captureWritesToStream(sys.stderr) as printed:
+        print_op = logging_ops.print_v2(sparse)
+        self.evaluate(print_op)
+      expected = ("'SparseTensor(indices=[[0 0]\n"
+                  " [1 0]\n"
+                  " [1 3]\n"
+                  " ...\n"
+                  " [1 4]\n"
+                  " [3 2]\n"
+                  " [3 3]], values=[0 10 13 ... 14 32 33], shape=[5 6])'")
+      self.assertTrue((expected + "\n") in printed.contents())
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testPrintSparseTensorInDataStruct(self):
+    with self.test_session():
+      ind = [[0, 0], [1, 0], [1, 3], [4, 1], [1, 4], [3, 2], [3, 3]]
+      val = [0, 10, 13, 4, 14, 32, 33]
+      shape = [5, 6]
+
+      sparse = sparse_tensor.SparseTensor(
+          constant_op.constant(ind, dtypes.int64),
+          constant_op.constant(val, dtypes.int64),
+          constant_op.constant(shape, dtypes.int64))
+
+      with self.captureWritesToStream(sys.stderr) as printed:
+        print_op = logging_ops.print_v2([sparse])
+        self.evaluate(print_op)
+      expected = ("['SparseTensor(indices=[[0 0]\n"
+                  " [1 0]\n"
+                  " [1 3]\n"
+                  " ...\n"
+                  " [1 4]\n"
+                  " [3 2]\n"
+                  " [3 3]], values=[0 10 13 ... 14 32 33], shape=[5 6])']")
+      self.assertTrue((expected + "\n") in printed.contents())
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testPrintOneTensorStdout(self):
+    with self.test_session():
+      tensor = math_ops.range(10)
+      with self.captureWritesToStream(sys.stdout) as printed:
+        print_op = logging_ops.print_v2(
+            tensor, output_stream=sys.stdout)
+        self.evaluate(print_op)
+      expected = "[0 1 2 ... 7 8 9]"
+      self.assertTrue((expected + "\n") in printed.contents())
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testPrintOneTensorLogInfo(self):
+    with self.test_session():
+      tensor = math_ops.range(10)
+      with self.captureWritesToStream(sys.stderr) as printed:
+        print_op = logging_ops.print_v2(
+            tensor, output_stream=tf_logging.info)
+        self.evaluate(print_op)
+      expected = "[0 1 2 ... 7 8 9]"
+      self.assertTrue(expected in printed.contents())
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testPrintOneTensorLogWarning(self):
+    with self.test_session():
+      tensor = math_ops.range(10)
+      with self.captureWritesToStream(sys.stderr) as printed:
+        print_op = logging_ops.print_v2(
+            tensor, output_stream=tf_logging.warning)
+        self.evaluate(print_op)
+      expected = "[0 1 2 ... 7 8 9]"
+      self.assertTrue(expected in printed.contents())
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testPrintOneTensorLogError(self):
+    with self.test_session():
+      tensor = math_ops.range(10)
+      with self.captureWritesToStream(sys.stderr) as printed:
+        print_op = logging_ops.print_v2(
+            tensor, output_stream=tf_logging.error)
+        self.evaluate(print_op)
+      expected = "[0 1 2 ... 7 8 9]"
+      self.assertTrue(expected in printed.contents())
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testInvalidOutputStreamRaisesError(self):
+    with self.test_session():
+      tensor = math_ops.range(10)
+      with self.assertRaises(ValueError):
+        print_op = logging_ops.print_v2(
+            tensor, output_stream="unknown")
+        self.evaluate(print_op)
+
+  def testPrintOpName(self):
+    with self.test_session():
+      tensor = math_ops.range(10)
+      print_op = logging_ops.print_v2(tensor, name="print_name")
+      self.assertEqual(print_op.name, "print_name")
+
+  def testNoDuplicateFormatOpGraphModeAfterExplicitFormat(self):
+    with self.test_session():
+      tensor = math_ops.range(10)
+      formatted_string = string_ops.string_format("{}", tensor)
+      print_op = logging_ops.print_v2(formatted_string)
+      self.evaluate(print_op)
+      graph_ops = ops.get_default_graph().get_operations()
+      format_ops = [op for op in graph_ops if op.type == "StringFormat"]
+      # Should be only 1 format_op for graph mode.
+      self.assertEqual(len(format_ops), 1)
+
+  def testPrintOneTensorEagerOnOpCreate(self):
+    with self.test_session():
+      with context.eager_mode():
+        tensor = math_ops.range(10)
+        expected = "[0 1 2 ... 7 8 9]"
+        with self.captureWritesToStream(sys.stderr) as printed:
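+          # In eager mode print_v2 executes (and prints) immediately when it is
+          # created, so there is no print op to evaluate.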
+          logging_ops.print_v2(tensor)
+        self.assertTrue((expected + "\n") in printed.contents())
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testPrintInDefunWithoutExplicitEvalOfPrint(self):
+    @function.defun
+    def f():
+      tensor = math_ops.range(10)
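+      # print_v2 is stateful, so defun keeps it in the function body and it
+      # should run on every call even though its result is unused.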
+      logging_ops.print_v2(tensor)
+      return tensor
+
+    expected = "[0 1 2 ... 7 8 9]"
+    with self.captureWritesToStream(sys.stderr) as printed_one:
+      x = f()
+      self.evaluate(x)
+    self.assertTrue((expected + "\n") in printed_one.contents())
+
+    # We execute the function again to make sure it doesn't only print on the
+    # first call.
+    with self.captureWritesToStream(sys.stderr) as printed_two:
+      y = f()
+      self.evaluate(y)
+    self.assertTrue((expected + "\n") in printed_two.contents())
+
+
 class PrintGradientTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes
@@ -65,6 +370,11 @@
     inp_printed = logging_ops.Print(inp, [inp])
     self.assertEqual(inp.get_shape(), inp_printed.get_shape())
 
+  def testPrintString(self):
+    inp = constant_op.constant(2.0, shape=[100, 32])
+    inp_printed = logging_ops.Print(inp, ["hello"])
+    self.assertEqual(inp.get_shape(), inp_printed.get_shape())
+
   def testPrintGradient(self):
     with self.cached_session():
       inp = constant_op.constant(2.0, shape=[100, 32], name="in")
diff --git a/tensorflow/python/kernel_tests/regex_full_match_op_test.py b/tensorflow/python/kernel_tests/regex_full_match_op_test.py
index e81f562..98746e7 100644
--- a/tensorflow/python/kernel_tests/regex_full_match_op_test.py
+++ b/tensorflow/python/kernel_tests/regex_full_match_op_test.py
@@ -42,7 +42,7 @@
 
   def testRegexFullMatchTwoDims(self, op):
     values = [["abaaba", "abcdabcde"], ["acdcba", "ebcda"]]
-    with self.test_session():
+    with self.cached_session():
       input_tensor = constant_op.constant(values, dtypes.string)
       matched = op(input_tensor, "a.*a").eval()
       self.assertAllEqual([[True, False], [True, False]], matched)
@@ -68,7 +68,7 @@
 
   def testRegexFullMatchDelegation(self):
     with compat.forward_compatibility_horizon(2018, 11, 1):
-      with self.test_session():
+      with self.cached_session():
         input_tensor = constant_op.constant("foo", dtypes.string)
         pattern = "[a-z]"
         op = string_ops.regex_full_match(input_tensor, pattern)
@@ -80,7 +80,7 @@
 
   def testStaticRegexFullMatchDelegation(self):
     with compat.forward_compatibility_horizon(2018, 11, 20):
-      with self.test_session():
+      with self.cached_session():
         input_tensor = constant_op.constant("foo", dtypes.string)
         pattern = "[a-z]*"
         op = string_ops.regex_full_match(input_tensor, pattern)
diff --git a/tensorflow/python/kernel_tests/regex_replace_op_test.py b/tensorflow/python/kernel_tests/regex_replace_op_test.py
index feac3a8..d9b7ed2 100644
--- a/tensorflow/python/kernel_tests/regex_replace_op_test.py
+++ b/tensorflow/python/kernel_tests/regex_replace_op_test.py
@@ -33,7 +33,7 @@
 class RegexReplaceOpVariantsTest(test.TestCase, parameterized.TestCase):
 
   def testForwarding(self, op):
-    with self.test_session():
+    with self.cached_session():
       # Generate an input that is uniquely consumed by the regex op.
       # This exercises code paths which are optimized for this case
       # (e.g., using forwarding).
@@ -47,7 +47,7 @@
 
   def testRemovePrefix(self, op):
     values = ["a:foo", "a:bar", "a:foo", "b:baz", "b:qux", "ca:b"]
-    with self.test_session():
+    with self.cached_session():
       input_vector = constant_op.constant(values, dtypes.string)
       stripped = op(input_vector, "^(a:|b:)", "", replace_global=False).eval()
       self.assertAllEqual([b"foo", b"bar", b"foo", b"baz", b"qux", b"ca:b"],
@@ -55,21 +55,21 @@
 
   def testRegexReplace(self, op):
     values = ["aba\naba", "abcdabcde"]
-    with self.test_session():
+    with self.cached_session():
       input_vector = constant_op.constant(values, dtypes.string)
       stripped = op(input_vector, "a.*a", "(\\0)").eval()
       self.assertAllEqual([b"(aba)\n(aba)", b"(abcda)bcde"], stripped)
 
   def testEmptyMatch(self, op):
     values = ["abc", "1"]
-    with self.test_session():
+    with self.cached_session():
       input_vector = constant_op.constant(values, dtypes.string)
       stripped = op(input_vector, "", "x").eval()
       self.assertAllEqual([b"xaxbxcx", b"x1x"], stripped)
 
   def testInvalidPattern(self, op):
     values = ["abc", "1"]
-    with self.test_session():
+    with self.cached_session():
       input_vector = constant_op.constant(values, dtypes.string)
       invalid_pattern = "A["
       replace = op(input_vector, invalid_pattern, "x")
@@ -78,7 +78,7 @@
 
   def testGlobal(self, op):
     values = ["ababababab", "abcabcabc", ""]
-    with self.test_session():
+    with self.cached_session():
       input_vector = constant_op.constant(values, dtypes.string)
       stripped = op(input_vector, "ab", "abc", True).eval()
       self.assertAllEqual([b"abcabcabcabcabc", b"abccabccabcc", b""], stripped)
@@ -99,7 +99,7 @@
       (as_tensor, as_string),
       (as_tensor, as_tensor))
   def testRegexReplaceDelegation(self, pattern_fn, rewrite_fn):
-    with self.test_session():
+    with self.cached_session():
       input_vector = constant_op.constant("foo", dtypes.string)
       pattern = pattern_fn("[a-z]")
       replace = rewrite_fn(".")
@@ -107,7 +107,7 @@
       self.assertTrue(op.name.startswith("RegexReplace"))
 
   def testStaticRegexReplaceDelegation(self):
-    with self.test_session():
+    with self.cached_session():
       input_vector = constant_op.constant("foo", dtypes.string)
       pattern = "[a-z]"
       replace = "."
diff --git a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
index ce507e4..2931877 100644
--- a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
+++ b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
@@ -300,7 +300,7 @@
               tf_ans = s.eval()
               if dtype is dtypes_lib.bfloat16:
                 tf_ans = tf_ans.astype(np.float32)
-              self.assertAllClose(np_ans, tf_ans)
+              self.assertAllCloseAccordingToType(np_ans, tf_ans)
               self.assertShapeEqual(np_ans, s)
 
   def testNumSegmentsTypes(self):
diff --git a/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py b/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py
index 4777203..a824d5c 100644
--- a/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py
+++ b/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py
@@ -195,7 +195,7 @@
       self.assertAllEqual([-1, 2], val.dense_shape)
 
   def testAccumulatorTakeGradSum(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.SparseConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=(), reduction_type="SUM")
 
@@ -289,7 +289,7 @@
           val, sess)
 
   def testParallelApplyGradSum(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.SparseConditionalAccumulator(
           dtypes_lib.float32,
           name="Q",
diff --git a/tensorflow/python/kernel_tests/string_format_op_test.py b/tensorflow/python/kernel_tests/string_format_op_test.py
new file mode 100644
index 0000000..afa71db
--- /dev/null
+++ b/tensorflow/python/kernel_tests/string_format_op_test.py
@@ -0,0 +1,384 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.kernels.logging_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+
+class StringFormatOpTest(test.TestCase):
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testFormatOneTensorOneDim(self):
+    with self.test_session():
+      tensor = math_ops.range(10)
+      format_output = string_ops.string_format("{}", tensor)
+      out = self.evaluate(format_output)
+      expected = "[0 1 2 ... 7 8 9]"
+      self.assertEqual(compat.as_text(out), expected)
+
+    with self.test_session():
+      tensor = math_ops.range(10)
+      format_output = string_ops.string_format("{}", [tensor])
+      out = self.evaluate(format_output)
+      expected = "[0 1 2 ... 7 8 9]"
+      self.assertEqual(compat.as_text(out), expected)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testFormatOneVariableScalar(self):
+    with self.test_session():
+      var = variables.Variable(3.34)
+      format_output = string_ops.string_format("{}", [var])
+      if not context.executing_eagerly():
+        variables.global_variables_initializer().run()
+      out = self.evaluate(format_output)
+      expected = "3.34"
+      self.assertEqual(compat.as_text(out), expected)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testFormatOneVariableOneDim(self):
+    with self.test_session():
+      var = variables.Variable(math_ops.range(10))
+      format_output = string_ops.string_format("{}", [var])
+      if not context.executing_eagerly():
+        variables.global_variables_initializer().run()
+      out = self.evaluate(format_output)
+      expected = "[0 1 2 ... 7 8 9]"
+      self.assertEqual(compat.as_text(out), expected)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testFormatTwoVariablesWithAssignAdd(self):
+    with self.test_session():
+      var_one = variables.Variable(2.14)
+      plus_one = var_one.assign_add(1.0)
+      var_two = variables.Variable(math_ops.range(10))
+      format_output = string_ops.string_format("{}, {}", [var_one, var_two])
+      if not context.executing_eagerly():
+        variables.global_variables_initializer().run()
+      self.evaluate(plus_one)
+      out = self.evaluate(format_output)
+      expected = "3.14, [0 1 2 ... 7 8 9]"
+      self.assertEqual(compat.as_text(out), expected)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testFormatOneTensorOneDimFloat(self):
+    with self.test_session():
+      tensor = constant_op.constant([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7])
+      format_output = string_ops.string_format("{}", tensor)
+      out = self.evaluate(format_output)
+      expected = "[0 0.1 0.2 ... 0.5 0.6 0.7]"
+      self.assertEqual(compat.as_text(out), expected)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testFormatOneTensorOneDimMatchesSummarize(self):
+    with self.test_session():
+      tensor = math_ops.range(6)
+      format_output = string_ops.string_format("{}", tensor, summarize=3)
+      out = self.evaluate(format_output)
+      expected = "[0 1 2 3 4 5]"
+      self.assertEqual(compat.as_text(out), expected)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testFormatOneTensorOneDimVarySummarize(self):
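+    # summarize=-1 disables truncation; a positive value keeps that many
+    # leading and trailing entries per dimension.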
+    with self.test_session():
+      tensor = math_ops.range(6)
+      format_output = string_ops.string_format("{}", tensor, summarize=-1)
+      out = self.evaluate(format_output)
+      expected = "[0 1 2 3 4 5]"
+      self.assertEqual(compat.as_text(out), expected)
+
+    with self.test_session():
+      tensor = math_ops.range(6)
+      format_output = string_ops.string_format("{}", tensor, summarize=1)
+      out = self.evaluate(format_output)
+      expected = "[0 ... 5]"
+      self.assertEqual(compat.as_text(out), expected)
+
+    with self.test_session():
+      tensor = math_ops.range(6)
+      format_output = string_ops.string_format("{}", tensor, summarize=2)
+      out = self.evaluate(format_output)
+      expected = "[0 1 ... 4 5]"
+      self.assertEqual(compat.as_text(out), expected)
+
+    with self.test_session():
+      tensor = math_ops.range(6)
+      format_output = string_ops.string_format("{}", tensor, summarize=10)
+      out = self.evaluate(format_output)
+      expected = "[0 1 2 3 4 5]"
+      self.assertEqual(compat.as_text(out), expected)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testFormatOneTensorOneDimAlmostSummarize(self):
+    with self.test_session():
+      tensor = math_ops.range(5)
+      format_output = string_ops.string_format("{}", tensor, summarize=3)
+      out = self.evaluate(format_output)
+      expected = "[0 1 2 3 4]"
+      self.assertEqual(compat.as_text(out), expected)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testFormatOneTensorTwoDimLessThanSummarize(self):
+    with self.test_session():
+      tensor = array_ops.reshape(math_ops.range(4), [2, 2])
+      format_output = string_ops.string_format("{}", tensor, summarize=3)
+      out = self.evaluate(format_output)
+      expected = ("[[0 1]\n"
+                  " [2 3]]")
+      self.assertEqual(compat.as_text(out), expected)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testFormatOneTensorTwoDim(self):
+    with self.test_session():
+      tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+      format_output = string_ops.string_format("{}", tensor)
+      out = self.evaluate(format_output)
+      expected = ("[[0 1 2 ... 7 8 9]\n"
+                  " [10 11 12 ... 17 18 19]\n"
+                  " [20 21 22 ... 27 28 29]\n"
+                  " ...\n"
+                  " [70 71 72 ... 77 78 79]\n"
+                  " [80 81 82 ... 87 88 89]\n"
+                  " [90 91 92 ... 97 98 99]]")
+      self.assertEqual(compat.as_text(out), expected)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testFormatOneTensorTwoDimSummarizeTwo(self):
+    with self.test_session():
+      tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+      format_output = string_ops.string_format("{}", tensor, summarize=2)
+      out = self.evaluate(format_output)
+      expected = ("[[0 1 ... 8 9]\n"
+                  " [10 11 ... 18 19]\n"
+                  " ...\n"
+                  " [80 81 ... 88 89]\n"
+                  " [90 91 ... 98 99]]")
+      self.assertEqual(compat.as_text(out), expected)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testFormatOneTensorThreeDim(self):
+    with self.test_session():
+      tensor = array_ops.reshape(math_ops.range(1000), [10, 10, 10])
+      format_output = string_ops.string_format("{}", tensor)
+      out = self.evaluate(format_output)
+      expected = ("[[[0 1 2 ... 7 8 9]\n"
+                  "  [10 11 12 ... 17 18 19]\n"
+                  "  [20 21 22 ... 27 28 29]\n"
+                  "  ...\n"
+                  "  [70 71 72 ... 77 78 79]\n"
+                  "  [80 81 82 ... 87 88 89]\n"
+                  "  [90 91 92 ... 97 98 99]]\n"
+                  "\n"
+                  " [[100 101 102 ... 107 108 109]\n"
+                  "  [110 111 112 ... 117 118 119]\n"
+                  "  [120 121 122 ... 127 128 129]\n"
+                  "  ...\n  [170 171 172 ... 177 178 179]\n"
+                  "  [180 181 182 ... 187 188 189]\n"
+                  "  [190 191 192 ... 197 198 199]]\n"
+                  "\n"
+                  " [[200 201 202 ... 207 208 209]\n"
+                  "  [210 211 212 ... 217 218 219]\n"
+                  "  [220 221 222 ... 227 228 229]\n"
+                  "  ...\n"
+                  "  [270 271 272 ... 277 278 279]\n"
+                  "  [280 281 282 ... 287 288 289]\n"
+                  "  [290 291 292 ... 297 298 299]]\n"
+                  "\n"
+                  " ...\n"
+                  "\n"
+                  " [[700 701 702 ... 707 708 709]\n"
+                  "  [710 711 712 ... 717 718 719]\n"
+                  "  [720 721 722 ... 727 728 729]\n"
+                  "  ...\n"
+                  "  [770 771 772 ... 777 778 779]\n"
+                  "  [780 781 782 ... 787 788 789]\n"
+                  "  [790 791 792 ... 797 798 799]]\n"
+                  "\n"
+                  " [[800 801 802 ... 807 808 809]\n"
+                  "  [810 811 812 ... 817 818 819]\n"
+                  "  [820 821 822 ... 827 828 829]\n"
+                  "  ...\n"
+                  "  [870 871 872 ... 877 878 879]\n"
+                  "  [880 881 882 ... 887 888 889]\n"
+                  "  [890 891 892 ... 897 898 899]]\n"
+                  "\n"
+                  " [[900 901 902 ... 907 908 909]\n"
+                  "  [910 911 912 ... 917 918 919]\n"
+                  "  [920 921 922 ... 927 928 929]\n"
+                  "  ...\n"
+                  "  [970 971 972 ... 977 978 979]\n"
+                  "  [980 981 982 ... 987 988 989]\n"
+                  "  [990 991 992 ... 997 998 999]]]")
+      self.assertEqual(compat.as_text(out), expected)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testFormatOneTensorTemplatePrefix(self):
+    with self.test_session():
+      tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+      format_output = string_ops.string_format("tensor summary: {}", tensor)
+      out = self.evaluate(format_output)
+      expected = ("tensor summary: [[0 1 2 ... 7 8 9]\n"
+                  " [10 11 12 ... 17 18 19]\n"
+                  " [20 21 22 ... 27 28 29]\n"
+                  " ...\n"
+                  " [70 71 72 ... 77 78 79]\n"
+                  " [80 81 82 ... 87 88 89]\n"
+                  " [90 91 92 ... 97 98 99]]")
+      self.assertEqual(compat.as_text(out), expected)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testFormatOneTensorTemplatePrefixAndSuffix(self):
+    with self.test_session():
+      tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+      format_output = string_ops.string_format("tensor summary: {}, suffix",
+                                               tensor)
+      out = self.evaluate(format_output)
+      expected = ("tensor summary: [[0 1 2 ... 7 8 9]\n"
+                  " [10 11 12 ... 17 18 19]\n"
+                  " [20 21 22 ... 27 28 29]\n"
+                  " ...\n"
+                  " [70 71 72 ... 77 78 79]\n"
+                  " [80 81 82 ... 87 88 89]\n"
+                  " [90 91 92 ... 97 98 99]], suffix")
+      self.assertEqual(compat.as_text(out), expected)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testFormatOneTensorTemplateSuffix(self):
+    with self.test_session():
+      tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+      format_output = string_ops.string_format("{}, suffix", tensor)
+      out = self.evaluate(format_output)
+      expected = ("[[0 1 2 ... 7 8 9]\n"
+                  " [10 11 12 ... 17 18 19]\n"
+                  " [20 21 22 ... 27 28 29]\n"
+                  " ...\n"
+                  " [70 71 72 ... 77 78 79]\n"
+                  " [80 81 82 ... 87 88 89]\n"
+                  " [90 91 92 ... 97 98 99]], suffix")
+      self.assertEqual(compat.as_text(out), expected)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testFormatNoTensor(self):
+    with self.test_session():
+      format_output = string_ops.string_format("No tensor.", ())
+      out = self.evaluate(format_output)
+      expected = "No tensor."
+      self.assertEqual(compat.as_text(out), expected)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testFormatMultiTensor(self):
+    with self.test_session():
+      tensor_one = array_ops.reshape(math_ops.range(100), [10, 10])
+      tensor_two = tensor_one * 10
+      format_output = string_ops.string_format("One: {},\nTwo: {}",
+                                               (tensor_one, tensor_two))
+      out = self.evaluate(format_output)
+      expected = ("One: [[0 1 2 ... 7 8 9]\n"
+                  " [10 11 12 ... 17 18 19]\n"
+                  " [20 21 22 ... 27 28 29]\n"
+                  " ...\n"
+                  " [70 71 72 ... 77 78 79]\n"
+                  " [80 81 82 ... 87 88 89]\n"
+                  " [90 91 92 ... 97 98 99]],\n"
+                  "Two: [[0 10 20 ... 70 80 90]\n"
+                  " [100 110 120 ... 170 180 190]\n"
+                  " [200 210 220 ... 270 280 290]\n"
+                  " ...\n"
+                  " [700 710 720 ... 770 780 790]\n"
+                  " [800 810 820 ... 870 880 890]\n"
+                  " [900 910 920 ... 970 980 990]]")
+      self.assertEqual(compat.as_text(out), expected)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testFormatSummarizeOne(self):
+    with self.test_session():
+      tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+      format_output = string_ops.string_format("tensor summary: {}", tensor,
+                                               summarize=1)
+      out = self.evaluate(format_output)
+      expected = ("tensor summary: [[0 ... 9]\n"
+                  " ...\n"
+                  " [90 ... 99]]")
+      self.assertEqual(compat.as_text(out), expected)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testFormatSummarizeTwo(self):
+    with self.test_session():
+      tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+      format_output = string_ops.string_format("tensor summary: {}", tensor,
+                                               summarize=2)
+      out = self.evaluate(format_output)
+      expected = ("tensor summary: [[0 1 ... 8 9]\n"
+                  " [10 11 ... 18 19]\n"
+                  " ...\n"
+                  " [80 81 ... 88 89]\n"
+                  " [90 91 ... 98 99]]")
+      self.assertEqual(compat.as_text(out), expected)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testFormatPlaceholder(self):
+    with self.test_session():
+      tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+      format_output = string_ops.string_format("tensor summary: %t%", tensor,
+                                               placeholder="%t%")
+      out = self.evaluate(format_output)
+      expected = ("tensor summary: [[0 1 2 ... 7 8 9]\n"
+                  " [10 11 12 ... 17 18 19]\n"
+                  " [20 21 22 ... 27 28 29]\n"
+                  " ...\n"
+                  " [70 71 72 ... 77 78 79]\n"
+                  " [80 81 82 ... 87 88 89]\n"
+                  " [90 91 92 ... 97 98 99]]")
+      self.assertEqual(compat.as_text(out), expected)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testTensorCountMustMatchPlaceholderCount(self):
+    with self.test_session():
+      with self.assertRaisesRegexp(
+          ValueError, r"2 placeholder\(s\) in template does not match 1 "
+                      r"tensor\(s\) provided as input"):
+        tensor = math_ops.range(10)
+        format_output = string_ops.string_format("{} {}", tensor)
+        self.evaluate(format_output)
+    with self.test_session():
+      with self.assertRaisesRegexp(
+          ValueError, r"2 placeholder\(s\) in template does not match 1 "
+                      r"tensor\(s\) provided as input"):
+        tensor = math_ops.range(10)
+        format_output = string_ops.string_format("{} {}", [tensor])
+        self.evaluate(format_output)
+    with self.test_session():
+      with self.assertRaisesRegexp(
+          ValueError, r"1 placeholder\(s\) in template does not match 2 "
+                      r"tensor\(s\) provided as input"):
+        tensor = math_ops.range(10)
+        format_output = string_ops.string_format("{}", (tensor, tensor))
+        self.evaluate(format_output)
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/python/kernel_tests/substr_op_test.py b/tensorflow/python/kernel_tests/substr_op_test.py
index 4d163a0..cd3fe14 100644
--- a/tensorflow/python/kernel_tests/substr_op_test.py
+++ b/tensorflow/python/kernel_tests/substr_op_test.py
@@ -46,7 +46,7 @@
     expected_value = b"ell"
 
     substr_op = string_ops.substr(test_string, position, length)
-    with self.test_session():
+    with self.cached_session():
       substr = substr_op.eval()
       self.assertAllEqual(substr, expected_value)
 
@@ -57,7 +57,7 @@
     expected_value = b""
 
     substr_op = string_ops.substr(test_string, position, length)
-    with self.test_session():
+    with self.cached_session():
       substr = substr_op.eval()
       self.assertAllEqual(substr, expected_value)
 
@@ -79,7 +79,7 @@
     expected_value = [b"ell", b"orl"]
 
     substr_op = string_ops.substr(test_string, position, length)
-    with self.test_session():
+    with self.cached_session():
       substr = substr_op.eval()
       self.assertAllEqual(substr, expected_value)
 
@@ -104,7 +104,7 @@
                       [b"ixte", b"even", b"ight"]]
 
     substr_op = string_ops.substr(test_string, position, length)
-    with self.test_session():
+    with self.cached_session():
       substr = substr_op.eval()
       self.assertAllEqual(substr, expected_value)
 
@@ -196,7 +196,7 @@
     position = np.array(-7, dtype)
     length = np.array(3, dtype)
     substr_op = string_ops.substr(test_string, position, length)
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(errors_impl.InvalidArgumentError):
         substr = substr_op.eval()
 
@@ -234,7 +234,7 @@
     position = np.array([[1, 2, -3], [1, 2, -4], [1, 2, -3]], dtype)
     length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype)
     substr_op = string_ops.substr(test_string, position, length)
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(errors_impl.InvalidArgumentError):
         substr = substr_op.eval()
 
@@ -252,7 +252,7 @@
     position = np.array([-1, -2, -4], dtype)
     length = np.array([1, 2, 3], dtype)
     substr_op = string_ops.substr(test_string, position, length)
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(errors_impl.InvalidArgumentError):
         substr = substr_op.eval()
 
diff --git a/tensorflow/python/kernel_tests/while_v2_test.py b/tensorflow/python/kernel_tests/while_v2_test.py
new file mode 100644
index 0000000..0c3b724
--- /dev/null
+++ b/tensorflow/python/kernel_tests/while_v2_test.py
@@ -0,0 +1,276 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for while_v2."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import meta_graph
+from tensorflow.python.framework import ops
+from tensorflow.python.grappler import tf_optimizer
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import list_ops
+from tensorflow.python.ops import while_v2
+from tensorflow.python.ops.control_flow_ops import while_loop as while_loop_v1
+from tensorflow.python.ops.while_v2 import while_loop as while_loop_v2
+from tensorflow.python.platform import test
+
+
+class WhileV2Test(test.TestCase, parameterized.TestCase):
+
+  def testSingleLoopVar(self):
+    x = constant_op.constant(2.)
+    ret = while_loop_v2(lambda v: v < 8., lambda v: v * v, [x])
+    grad = gradients_impl.gradients(ret, [x])
+    with self.test_session() as sess:
+      self.assertEqual(sess.run(ret), 16.)
+      self.assertSequenceEqual(sess.run(grad), [32.])
+
+  def testMultipleLoopVarsBasic(self):
+    x = constant_op.constant(5.)
+    y = constant_op.constant(3.)
+
+    # x = 5.
+    # y = 3.
+    # while x < 45.:
+    #   x = x * y
+    ret = while_loop_v2(lambda v, _: v < 45., lambda v, w: (v * w, w), [x, y])
+    # ret = [x*y^2, y]
+
+    # Note: This is simply d_ret[0]/d_x since d_ret[1]/d_x is 0.
+    grad = gradients_impl.gradients(ret, [x])  # [2*x*y]
+    with self.test_session() as sess:
+      self.assertSequenceEqual(sess.run(ret), [45., 3.])
+      self.assertSequenceEqual(sess.run(grad), [9.])
+
+  def testMultipleLoopVars(self):
+    x = constant_op.constant(5.)
+    y = constant_op.constant(3.)
+
+    # x = 5.
+    # y = 3.
+    # while x < 45.:
+    #   x = x * y
+    #   y = x + y
+    ret = while_loop_v2(lambda v, _: v < 45., lambda v, w: (v * w, v + w),
+                        [x, y])
+    # ret = [y*x**2 + x*y**2, x*y + x + y]
+
+    gradx_0 = gradients_impl.gradients(ret[0], [x])  # [2*x*y + y**2]
+    gradx_1 = gradients_impl.gradients(ret[1], [x])  # [y + 1]
+    gradx_2 = gradients_impl.gradients(ret, [x])  # [2*x*y + y**2 + 2*y + 1]
+    grady_0 = gradients_impl.gradients(ret[0], [y])  # [2*x*y + x**2]
+    grady_1 = gradients_impl.gradients(ret[1], [y])  # [x + 1]
+    grady_2 = gradients_impl.gradients(ret, [y])  # [2*x*y + x**2 + x + 1]
+    with self.test_session() as sess:
+      self.assertSequenceEqual(sess.run(ret), [120., 23.])
+      self.assertSequenceEqual(sess.run(gradx_0), [39.])
+      self.assertSequenceEqual(sess.run(gradx_1), [4.])
+      self.assertSequenceEqual(sess.run(gradx_2), [43.])
+      self.assertSequenceEqual(sess.run(grady_0), [55.])
+      self.assertSequenceEqual(sess.run(grady_1), [6.])
+      self.assertSequenceEqual(sess.run(grady_2), [61.])
+
+  def testMultipleWhileLoops(self):
+    x = constant_op.constant(2.)
+    ret1 = while_loop_v2(lambda v: v < 4., lambda v: v * v, [x])  # x**2
+    ret2 = while_loop_v2(lambda v: v < 16., lambda v: v * v, ret1)  # x**4
+    grad = gradients_impl.gradients(ret2, [x])  # 4x**3
+    grad_grad = gradients_impl.gradients(grad, [x])  # 12x**2
+    with self.test_session() as sess:
+      self.assertSequenceEqual(sess.run(grad), [32.])
+      self.assertSequenceEqual(sess.run(grad_grad), [48.])
+
+  def testDoubleDerivative(self):
+    x = constant_op.constant(2.)
+    ret = while_loop_v2(lambda v: v < 8., lambda v: v**2, [x])  # x**4
+    grad = gradients_impl.gradients(ret, [x])  # 4x**3
+    grad_grad = gradients_impl.gradients(grad, [x])  # 12x**2
+    with self.test_session() as sess:
+      self.assertEqual(sess.run(ret), 16.)
+      self.assertSequenceEqual(sess.run(grad), [32.])
+      self.assertSequenceEqual(sess.run(grad_grad), [48.])
+
+  def testPruning(self):
+    x = constant_op.constant(1)
+
+    tensor_list = list_ops.empty_tensor_list(
+        element_dtype=x.dtype, element_shape=x.shape)
+
+    def Cond(x, tl):
+      del tl  # Unused for Cond.
+      return x < 5
+
+    def Body(x, tl):
+      return x + 1, list_ops.tensor_list_push_back(tl, x)
+
+    outputs = while_loop_v1(Cond, Body, [x, tensor_list])
+
+    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
+    train_op.append(outputs[0])
+
+    def GetOptimizedGraph():
+      mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
+      rewriter_config = rewriter_config_pb2.RewriterConfig(
+          constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
+          memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
+      return tf_optimizer.OptimizeGraph(rewriter_config, mg)
+
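+    # With only the loop output fetched, the TensorList accumulator is dead
+    # code and Grappler should prune it, leaving a single Enter node.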
+    g = GetOptimizedGraph()
+    self.assertEqual(len([n for n in g.node if n.op == "Enter"]), 1)
+
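+    # Once the stacked TensorList is also fetched, the accumulator must be
+    # kept, so a second Enter node survives optimization.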
+    stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
+    train_op.append(stack)
+    g = GetOptimizedGraph()
+    self.assertEqual(len([n for n in g.node if n.op == "Enter"]), 2)
+
+  def testCaptureExternalTensorInCond(self):
+    x = constant_op.constant(2.)
+    y = constant_op.constant(1.)
+    ret = while_loop_v2(lambda v: v + y < 9., lambda v: v * 3., [x])
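+    # With x = 2 and y = 1 the loop runs twice (2 -> 6 -> 18), so ret = 9*x
+    # and d(ret)/dx = 9.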
+    grad = gradients_impl.gradients(ret, [x])
+    with self.test_session() as sess:
+      self.assertEqual(sess.run(ret), 18.)
+      self.assertSequenceEqual(sess.run(grad), [9.])
+
+  def testCaptureExternalTensorInBody(self):
+    x = constant_op.constant(2.)
+    y = constant_op.constant(3.)
+    ret = while_loop_v2(lambda v: v < 8., lambda v: v * y, [x])
+    grad = gradients_impl.gradients(ret, [x])
+    with self.test_session() as sess:
+      self.assertEqual(sess.run(ret), 18.)
+      self.assertSequenceEqual(sess.run(grad), [9.])
+
+  def testLoopWithTensorListPushBack(self):
+    x = constant_op.constant(2.)
+
+    tensor_list = list_ops.empty_tensor_list(
+        element_dtype=dtypes.float32, element_shape=ScalarShape())
+
+    def Cond(x, tl):
+      del tl  # Unused for Cond.
+      return x < 5.
+
+    def Body(x, tl):
+      tl = list_ops.tensor_list_push_back(tl, x)
+      tl = list_ops.tensor_list_push_back(tl, constant_op.constant(100.))
+      return x**2., tl
+
+    ret = while_loop_v2(Cond, Body, [x, tensor_list])
+    grad = gradients_impl.gradients(ret[0], x)
+    with self.test_session() as sess:
+      self.assertEqual(sess.run(ret[0]), 16.)
+      self.assertSequenceEqual(sess.run(grad), [32.])
+
+  def testDuplicateAccumulator(self):
+    x = constant_op.constant(2.)
+
+    tensor_list = list_ops.empty_tensor_list(
+        element_dtype=dtypes.float32, element_shape=ScalarShape())
+
+    def Cond(x, tl):
+      del tl  # Unused for Cond.
+      return x < 5.
+
+    def Body(x, tl):
+      # There is an accumulator in the loop already so we should not add
+      # another.
+      tl = list_ops.tensor_list_push_back(tl, x)
+      return x**2., tl
+
+    ret = while_loop_v2(Cond, Body, [x, tensor_list])
+
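+    # Locate the While op so we can check that the list pushed in Body was
+    # reused as the accumulator for x instead of adding a second one.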
+    for op in ops.get_default_graph().get_operations():
+      if op.type == "While":
+        while_op = op
+
+    body_graph = while_v2._get_body_graph(while_op)
+    # body_graph.inputs: [counter_arg, x_arg, tl_arg, *accumulators]
+    x_input_t = body_graph.inputs[1]
+    accumulator_count = len(
+        [c for c in x_input_t.consumers() if c.type == "TensorListPushBack"])
+    self.assertEqual(accumulator_count, 1)
+
+    grad = gradients_impl.gradients(ret[0], x)
+    with self.test_session() as sess:
+      self.assertEqual(sess.run(ret[0]), 16.)
+      self.assertSequenceEqual(sess.run(grad), [32.])
+
+  @parameterized.named_parameters(
+      ("UnknownShape", None),
+      ("PartiallyDefinedShape", [None, 2]),
+      ("FullyDefinedShape", [1, 2]),
+  )
+  def testTensorListOutputElementShape(self, shape):
+
+    def MatchShape(actual_tensor_shape):
+      # Compare the shapes, treating None dimensions as equal. We do not
+      # directly check actual_tensor_shape and tf.TensorShape(shape) for
+      # equality because tf.Dimension.__eq__ returns None if either dimension is
+      # None.
+      if shape is None:
+        self.assertIsNone(actual_tensor_shape.dims)
+      else:
+        self.assertListEqual(actual_tensor_shape.as_list(), shape)
+
+    def GetAccumulatorForInputAtIndex(while_op, idx):
+      body_graph = while_v2._get_body_graph(while_op)
+      y_input_t = body_graph.inputs[idx]
+      push_back_node = [c for c in y_input_t.consumers()
+                        if c.type == "TensorListPushBack"][0]
+      output_idx = body_graph.outputs.index(push_back_node.outputs[0])
+      return while_op.outputs[output_idx]
+
+    x = constant_op.constant(2.)
+    y = array_ops.placeholder(dtype=dtypes.float32, shape=shape)
+
+    # Forward pass.
+    ret = while_loop_v2(lambda v, u: v < 8., lambda v, u: (v * v, u), [x, y])
+    while_op = ret[0].op
+    # Get the TensorList output of While op containing the accumulated values
+    # of y.
+    # while_op.inputs: [counter_arg, x_arg, y_arg, *accumulators]
+    output = GetAccumulatorForInputAtIndex(while_op, 2)
+    _, val = list_ops.tensor_list_pop_back(output,
+                                           element_dtype=dtypes.float32)
+    MatchShape(val.shape)
+
+    # Gradient pass.
+    grad = gradients_impl.gradients(ret[1], y)
+    grad_while_op = grad[0].op
+    # Get the TensorList output of gradient While op containing the accumulated
+    # values of grad_y.
+    # grad_while_op.inputs:
+    # [counter_arg, total_iters_arg, grad_x_arg, grad_y_arg, *other_args]
+    grad_output = GetAccumulatorForInputAtIndex(grad_while_op, 4)
+    _, val = list_ops.tensor_list_pop_back(grad_output,
+                                           element_dtype=dtypes.float32)
+    MatchShape(val.shape)
+
+
+def ScalarShape():
+  return ops.convert_to_tensor([], dtype=dtypes.int32)
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index c8b8833..a7f57e9 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -2787,4 +2787,65 @@
       name=name)
 
 
+@tf_export("searchsorted")
+def searchsorted(sorted_sequence,
+                 values,
+                 side="left",
+                 out_type=dtypes.int32,
+                 name=None):
+  """Searches input tensor for values on the innermost dimension.
+
+  A 2-D example:
+
+  ```
+    sorted_sequence = [[0, 3, 9, 9, 10],
+                       [1, 2, 3, 4, 5]]
+    values = [[2, 4, 9],
+              [0, 2, 6]]
+
+    result = searchsorted(sorted_sequence, values, side="left")
+
+    result == [[1, 2, 2],
+               [0, 1, 5]]
+
+    result = searchsorted(sorted_sequence, values, side="right")
+
+    result == [[1, 2, 4],
+               [0, 2, 5]]
+  ```
+
+  Args:
+    sorted_sequence: N-D `Tensor` containing a sorted sequence.
+    values: N-D `Tensor` containing the search values.
+    side: 'left' or 'right'; 'left' corresponds to lower_bound and 'right' to
+      upper_bound.
+    out_type: The output type (`int32` or `int64`).  Default is `tf.int32`.
+    name: Optional name for the operation.
+
+  Returns:
+    An N-D `Tensor` the same size as `values`, containing the result of
+    applying either lower_bound or upper_bound (depending on `side`) to each
+    value.  The result is not a global index into the entire `Tensor`, but the
+    index within the last dimension.
+
+  Raises:
+    ValueError: If the last dimension of `sorted_sequence` has more than
+                `2^31 - 1` elements, if the total size of `values` exceeds
+                `2^31 - 1` elements, or if the first `N-1` dimensions of the
+                two tensors don't match.
+  """
+  sequence_size = shape_internal(sorted_sequence)[-1]
+  values_size = shape_internal(values)[-1]
+  sorted_sequence_2d = reshape(sorted_sequence, [-1, sequence_size])
+  values_2d = reshape(values, [-1, values_size])
+  if side == "right":
+    output = gen_array_ops.upper_bound(sorted_sequence_2d, values_2d, out_type,
+                                       name)
+  elif side == "left":
+    output = gen_array_ops.lower_bound(sorted_sequence_2d, values_2d, out_type,
+                                       name)
+  else:
+    raise ValueError("side must be either 'right' or 'left'.  Saw: %s." % side)
+  return reshape(output, shape_internal(values))
+
+
 quantize.__doc__ = gen_array_ops.quantize_v2.__doc__
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 0e20fad..87f8bd8 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -610,9 +610,10 @@
           "less-specific shape." %
           (input_t.name, input_t.shape, n_shape))
   else:
-    if not isinstance(var, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
-      raise TypeError("Type %s not supported" % type(var))
-    if isinstance(var, ops.IndexedSlices):
+    if not isinstance(merge_var,
+                      (ops.IndexedSlices, sparse_tensor.SparseTensor)):
+      raise TypeError("Type %s not supported" % type(merge_var))
+    if isinstance(merge_var, ops.IndexedSlices):
       m_values_shape = merge_var.values.get_shape()
       m_indices_shape = merge_var.indices.get_shape()
       m_shape_shape = tensor_shape.TensorShape(None)
diff --git a/tensorflow/python/ops/distributions/beta.py b/tensorflow/python/ops/distributions/beta.py
index 99d30b0..2ba1ea6 100644
--- a/tensorflow/python/ops/distributions/beta.py
+++ b/tensorflow/python/ops/distributions/beta.py
@@ -98,10 +98,13 @@
   #### Examples
 
   ```python
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
+
   # Create a batch of three Beta distributions.
   alpha = [1, 2, 3]
   beta = [1, 2, 3]
-  dist = tf.distributions.Beta(alpha, beta)
+  dist = tfd.Beta(alpha, beta)
 
   dist.sample([4, 5])  # Shape [4, 5, 3]
 
@@ -117,7 +120,7 @@
   # Create batch_shape=[2, 3] via parameter broadcast:
   alpha = [[1.], [2]]      # Shape [2, 1]
   beta = [3., 4, 5]        # Shape [3]
-  dist = tf.distributions.Beta(alpha, beta)
+  dist = tfd.Beta(alpha, beta)
 
   # alpha broadcast as: [[1., 1, 1,],
   #                      [2, 2, 2]]
@@ -138,7 +141,7 @@
   ```python
   alpha = tf.constant(1.0)
   beta = tf.constant(2.0)
-  dist = tf.distributions.Beta(alpha, beta)
+  dist = tfd.Beta(alpha, beta)
   samples = dist.sample(5)  # Shape [5]
   loss = tf.reduce_mean(tf.square(samples))  # Arbitrary loss function
   # Unbiased stochastic gradients of the loss function
diff --git a/tensorflow/python/ops/distributions/dirichlet.py b/tensorflow/python/ops/distributions/dirichlet.py
index 9104a1d..415249a 100644
--- a/tensorflow/python/ops/distributions/dirichlet.py
+++ b/tensorflow/python/ops/distributions/dirichlet.py
@@ -104,10 +104,13 @@
   #### Examples
 
   ```python
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
+
   # Create a single trivariate Dirichlet, with the 3rd class being three times
   # more frequent than the first. I.e., batch_shape=[], event_shape=[3].
   alpha = [1., 2, 3]
-  dist = tf.distributions.Dirichlet(alpha)
+  dist = tfd.Dirichlet(alpha)
 
   dist.sample([4, 5])  # shape: [4, 5, 3]
 
@@ -129,7 +132,7 @@
   # Create batch_shape=[2], event_shape=[3]:
   alpha = [[1., 2, 3],
            [4, 5, 6]]   # shape: [2, 3]
-  dist = tf.distributions.Dirichlet(alpha)
+  dist = tfd.Dirichlet(alpha)
 
   dist.sample([4, 5])  # shape: [4, 5, 2, 3]
 
@@ -144,7 +147,7 @@
 
   ```python
   alpha = tf.constant([1.0, 2.0, 3.0])
-  dist = tf.distributions.Dirichlet(alpha)
+  dist = tfd.Dirichlet(alpha)
   samples = dist.sample(5)  # Shape [5, 3]
   loss = tf.reduce_mean(tf.square(samples))  # Arbitrary loss function
   # Unbiased stochastic gradients of the loss function
diff --git a/tensorflow/python/ops/distributions/distribution.py b/tensorflow/python/ops/distributions/distribution.py
index 578e7b7..76d9806 100644
--- a/tensorflow/python/ops/distributions/distribution.py
+++ b/tensorflow/python/ops/distributions/distribution.py
@@ -601,7 +601,8 @@
     return type(self)(**parameters)
 
   def _batch_shape_tensor(self):
-    raise NotImplementedError("batch_shape_tensor is not implemented")
+    raise NotImplementedError(
+        "batch_shape_tensor is not implemented: {}".format(type(self).__name__))
 
   def batch_shape_tensor(self, name="batch_shape_tensor"):
     """Shape of a single sample from a single event index as a 1-D `Tensor`.
@@ -640,7 +641,8 @@
     return tensor_shape.as_shape(self._batch_shape())
 
   def _event_shape_tensor(self):
-    raise NotImplementedError("event_shape_tensor is not implemented")
+    raise NotImplementedError(
+        "event_shape_tensor is not implemented: {}".format(type(self).__name__))
 
   def event_shape_tensor(self, name="event_shape_tensor"):
     """Shape of a single sample from a single batch as a 1-D int32 `Tensor`.
@@ -701,7 +703,8 @@
           name="is_scalar_batch")
 
   def _sample_n(self, n, seed=None):
-    raise NotImplementedError("sample_n is not implemented")
+    raise NotImplementedError("sample_n is not implemented: {}".format(
+        type(self).__name__))
 
   def _call_sample_n(self, sample_shape, seed, name, **kwargs):
     with self._name_scope(name, values=[sample_shape]):
@@ -733,15 +736,19 @@
     return self._call_sample_n(sample_shape, seed, name)
 
   def _log_prob(self, value):
-    raise NotImplementedError("log_prob is not implemented")
+    raise NotImplementedError("log_prob is not implemented: {}".format(
+        type(self).__name__))
 
   def _call_log_prob(self, value, name, **kwargs):
     with self._name_scope(name, values=[value]):
       value = ops.convert_to_tensor(value, name="value")
       try:
         return self._log_prob(value, **kwargs)
-      except NotImplementedError:
-        return math_ops.log(self._prob(value, **kwargs))
+      except NotImplementedError as original_exception:
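+        # Fall back to log(prob); if _prob is also unimplemented, re-raise the
+        # original error so it points at log_prob.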
+        try:
+          return math_ops.log(self._prob(value, **kwargs))
+        except NotImplementedError:
+          raise original_exception
 
   def log_prob(self, value, name="log_prob"):
     """Log probability density/mass function.
@@ -757,15 +764,19 @@
     return self._call_log_prob(value, name)
 
   def _prob(self, value):
-    raise NotImplementedError("prob is not implemented")
+    raise NotImplementedError("prob is not implemented: {}".format(
+        type(self).__name__))
 
   def _call_prob(self, value, name, **kwargs):
     with self._name_scope(name, values=[value]):
       value = ops.convert_to_tensor(value, name="value")
       try:
         return self._prob(value, **kwargs)
-      except NotImplementedError:
-        return math_ops.exp(self._log_prob(value, **kwargs))
+      except NotImplementedError as original_exception:
+        try:
+          return math_ops.exp(self._log_prob(value, **kwargs))
+        except NotImplementedError:
+          raise original_exception
 
   def prob(self, value, name="prob"):
     """Probability density/mass function.
@@ -781,15 +792,19 @@
     return self._call_prob(value, name)
 
   def _log_cdf(self, value):
-    raise NotImplementedError("log_cdf is not implemented")
+    raise NotImplementedError("log_cdf is not implemented: {}".format(
+        type(self).__name__))
 
   def _call_log_cdf(self, value, name, **kwargs):
     with self._name_scope(name, values=[value]):
       value = ops.convert_to_tensor(value, name="value")
       try:
         return self._log_cdf(value, **kwargs)
-      except NotImplementedError:
-        return math_ops.log(self._cdf(value, **kwargs))
+      except NotImplementedError as original_exception:
+        try:
+          return math_ops.log(self._cdf(value, **kwargs))
+        except NotImplementedError:
+          raise original_exception
 
   def log_cdf(self, value, name="log_cdf"):
     """Log cumulative distribution function.
@@ -815,15 +830,19 @@
     return self._call_log_cdf(value, name)
 
   def _cdf(self, value):
-    raise NotImplementedError("cdf is not implemented")
+    raise NotImplementedError("cdf is not implemented: {}".format(
+        type(self).__name__))
 
   def _call_cdf(self, value, name, **kwargs):
     with self._name_scope(name, values=[value]):
       value = ops.convert_to_tensor(value, name="value")
       try:
         return self._cdf(value, **kwargs)
-      except NotImplementedError:
-        return math_ops.exp(self._log_cdf(value, **kwargs))
+      except NotImplementedError as original_exception:
+        try:
+          return math_ops.exp(self._log_cdf(value, **kwargs))
+        except NotImplementedError:
+          raise original_exception
 
   def cdf(self, value, name="cdf"):
     """Cumulative distribution function.
@@ -845,15 +864,20 @@
     return self._call_cdf(value, name)
 
   def _log_survival_function(self, value):
-    raise NotImplementedError("log_survival_function is not implemented")
+    raise NotImplementedError(
+        "log_survival_function is not implemented: {}".format(
+            type(self).__name__))
 
   def _call_log_survival_function(self, value, name, **kwargs):
     with self._name_scope(name, values=[value]):
       value = ops.convert_to_tensor(value, name="value")
       try:
         return self._log_survival_function(value, **kwargs)
-      except NotImplementedError:
-        return math_ops.log1p(-self.cdf(value, **kwargs))
+      except NotImplementedError as original_exception:
+        try:
+          return math_ops.log1p(-self.cdf(value, **kwargs))
+        except NotImplementedError:
+          raise original_exception
 
   def log_survival_function(self, value, name="log_survival_function"):
     """Log survival function.
@@ -880,15 +904,19 @@
     return self._call_log_survival_function(value, name)
 
   def _survival_function(self, value):
-    raise NotImplementedError("survival_function is not implemented")
+    raise NotImplementedError("survival_function is not implemented: {}".format(
+        type(self).__name__))
 
   def _call_survival_function(self, value, name, **kwargs):
     with self._name_scope(name, values=[value]):
       value = ops.convert_to_tensor(value, name="value")
       try:
         return self._survival_function(value, **kwargs)
-      except NotImplementedError:
-        return 1. - self.cdf(value, **kwargs)
+      except NotImplementedError as original_exception:
+        try:
+          return 1. - self.cdf(value, **kwargs)
+        except NotImplementedError:
+          raise original_exception
 
   def survival_function(self, value, name="survival_function"):
     """Survival function.
@@ -912,7 +940,8 @@
     return self._call_survival_function(value, name)
 
   def _entropy(self):
-    raise NotImplementedError("entropy is not implemented")
+    raise NotImplementedError("entropy is not implemented: {}".format(
+        type(self).__name__))
 
   def entropy(self, name="entropy"):
     """Shannon entropy in nats."""
@@ -920,7 +949,8 @@
       return self._entropy()
 
   def _mean(self):
-    raise NotImplementedError("mean is not implemented")
+    raise NotImplementedError("mean is not implemented: {}".format(
+        type(self).__name__))
 
   def mean(self, name="mean"):
     """Mean."""
@@ -928,7 +958,8 @@
       return self._mean()
 
   def _quantile(self, value):
-    raise NotImplementedError("quantile is not implemented")
+    raise NotImplementedError("quantile is not implemented: {}".format(
+        type(self).__name__))
 
   def _call_quantile(self, value, name, **kwargs):
     with self._name_scope(name, values=[value]):
@@ -955,7 +986,8 @@
     return self._call_quantile(value, name)
 
   def _variance(self):
-    raise NotImplementedError("variance is not implemented")
+    raise NotImplementedError("variance is not implemented: {}".format(
+        type(self).__name__))
 
   def variance(self, name="variance"):
     """Variance.
@@ -979,11 +1011,15 @@
     with self._name_scope(name):
       try:
         return self._variance()
-      except NotImplementedError:
-        return math_ops.square(self._stddev())
+      except NotImplementedError as original_exception:
+        try:
+          return math_ops.square(self._stddev())
+        except NotImplementedError:
+          raise original_exception
 
   def _stddev(self):
-    raise NotImplementedError("stddev is not implemented")
+    raise NotImplementedError("stddev is not implemented: {}".format(
+        type(self).__name__))
 
   def stddev(self, name="stddev"):
     """Standard deviation.
@@ -1008,11 +1044,15 @@
     with self._name_scope(name):
       try:
         return self._stddev()
-      except NotImplementedError:
-        return math_ops.sqrt(self._variance())
+      except NotImplementedError as original_exception:
+        try:
+          return math_ops.sqrt(self._variance())
+        except NotImplementedError:
+          raise original_exception
 
   def _covariance(self):
-    raise NotImplementedError("covariance is not implemented")
+    raise NotImplementedError("covariance is not implemented: {}".format(
+        type(self).__name__))
 
   def covariance(self, name="covariance"):
     """Covariance.
@@ -1054,7 +1094,8 @@
       return self._covariance()
 
   def _mode(self):
-    raise NotImplementedError("mode is not implemented")
+    raise NotImplementedError("mode is not implemented: {}".format(
+        type(self).__name__))
 
   def mode(self, name="mode"):
     """Mode."""
@@ -1080,7 +1121,7 @@
     where `F` denotes the support of the random variable `X ~ P`.
 
     Args:
-      other: `tf.distributions.Distribution` instance.
+      other: `tfp.distributions.Distribution` instance.
       name: Python `str` prepended to names of ops created by this function.
 
     Returns:
@@ -1111,7 +1152,7 @@
     denotes (Shanon) cross entropy, and `H[.]` denotes (Shanon) entropy.
 
     Args:
-      other: `tf.distributions.Distribution` instance.
+      other: `tfp.distributions.Distribution` instance.
       name: Python `str` prepended to names of ops created by this function.
 
     Returns:
@@ -1123,7 +1164,7 @@
       return self._kl_divergence(other)
 
   def __str__(self):
-    return ("tf.distributions.{type_name}("
+    return ("tfp.distributions.{type_name}("
             "\"{self_name}\""
             "{maybe_batch_shape}"
             "{maybe_event_shape}"
@@ -1139,7 +1180,7 @@
                 dtype=self.dtype.name))
 
   def __repr__(self):
-    return ("<tf.distributions.{type_name} "
+    return ("<tfp.distributions.{type_name} "
             "'{self_name}'"
             " batch_shape={batch_shape}"
             " event_shape={event_shape}"
diff --git a/tensorflow/python/ops/distributions/gamma.py b/tensorflow/python/ops/distributions/gamma.py
index b631f02..3293cda 100644
--- a/tensorflow/python/ops/distributions/gamma.py
+++ b/tensorflow/python/ops/distributions/gamma.py
@@ -100,8 +100,11 @@
   #### Examples
 
   ```python
-  dist = tf.distributions.Gamma(concentration=3.0, rate=2.0)
-  dist2 = tf.distributions.Gamma(concentration=[3.0, 4.0], rate=[2.0, 3.0])
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
+
+  dist = tfd.Gamma(concentration=3.0, rate=2.0)
+  dist2 = tfd.Gamma(concentration=[3.0, 4.0], rate=[2.0, 3.0])
   ```
 
   Compute the gradients of samples w.r.t. the parameters:
@@ -109,7 +112,7 @@
   ```python
   concentration = tf.constant(3.0)
   rate = tf.constant(2.0)
-  dist = tf.distributions.Gamma(concentration, rate)
+  dist = tfd.Gamma(concentration, rate)
   samples = dist.sample(5)  # Shape [5]
   loss = tf.reduce_mean(tf.square(samples))  # Arbitrary loss function
   # Unbiased stochastic gradients of the loss function
diff --git a/tensorflow/python/ops/distributions/kullback_leibler.py b/tensorflow/python/ops/distributions/kullback_leibler.py
index e3c6f3e..fdeb97b 100644
--- a/tensorflow/python/ops/distributions/kullback_leibler.py
+++ b/tensorflow/python/ops/distributions/kullback_leibler.py
@@ -127,8 +127,8 @@
   where `F` denotes the support of the random variable `X ~ P`.
 
   Args:
-    ref: `tf.distributions.Distribution` instance.
-    other: `tf.distributions.Distribution` instance.
+    ref: `tfd.Distribution` instance.
+    other: `tfd.Distribution` instance.
     allow_nan_stats: Python `bool`, default `True`. When `True`,
       statistics (e.g., mean, mode, variance) use the value "`NaN`" to
       indicate the result is undefined. When `False`, an exception is raised
diff --git a/tensorflow/python/ops/distributions/normal.py b/tensorflow/python/ops/distributions/normal.py
index d0a987b..2feaf80 100644
--- a/tensorflow/python/ops/distributions/normal.py
+++ b/tensorflow/python/ops/distributions/normal.py
@@ -71,15 +71,18 @@
   Examples of initialization of one or a batch of distributions.
 
   ```python
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
+
   # Define a single scalar Normal distribution.
-  dist = tf.distributions.Normal(loc=0., scale=3.)
+  dist = tfd.Normal(loc=0., scale=3.)
 
   # Evaluate the cdf at 1, returning a scalar.
   dist.cdf(1.)
 
   # Define a batch of two scalar valued Normals.
   # The first has mean 1 and standard deviation 11, the second 2 and 22.
-  dist = tf.distributions.Normal(loc=[1, 2.], scale=[11, 22.])
+  dist = tfd.Normal(loc=[1, 2.], scale=[11, 22.])
 
   # Evaluate the pdf of the first distribution on 0, and the second on 1.5,
   # returning a length two tensor.
@@ -94,7 +97,7 @@
   ```python
   # Define a batch of two scalar valued Normals.
   # Both have mean 1, but different standard deviations.
-  dist = tf.distributions.Normal(loc=1., scale=[11, 22.])
+  dist = tfd.Normal(loc=1., scale=[11, 22.])
 
   # Evaluate the pdf of both distributions on the same point, 3.0,
   # returning a length 2 tensor.
diff --git a/tensorflow/python/ops/distributions/student_t.py b/tensorflow/python/ops/distributions/student_t.py
index e0cf6f8..e8d214b 100644
--- a/tensorflow/python/ops/distributions/student_t.py
+++ b/tensorflow/python/ops/distributions/student_t.py
@@ -91,8 +91,11 @@
   Examples of initialization of one or a batch of distributions.
 
   ```python
+  import tensorflow_probability as tfp
+  tfd = tfp.distributions
+
   # Define a single scalar Student t distribution.
-  single_dist = tf.distributions.StudentT(df=3)
+  single_dist = tfd.StudentT(df=3)
 
   # Evaluate the pdf at 1, returning a scalar Tensor.
   single_dist.prob(1.)
@@ -100,9 +103,7 @@
   # Define a batch of two scalar valued Student t's.
   # The first has degrees of freedom 2, mean 1, and scale 11.
   # The second 3, 2 and 22.
-  multi_dist = tf.distributions.StudentT(df=[2, 3],
-                                                 loc=[1, 2.],
-                                                 scale=[11, 22.])
+  multi_dist = tfd.StudentT(df=[2, 3], loc=[1, 2.], scale=[11, 22.])
 
   # Evaluate the pdf of the first distribution on 0, and the second on 1.5,
   # returning a length two tensor.
@@ -117,7 +118,7 @@
   ```python
   # Define a batch of two Student's t distributions.
   # Both have df 2 and mean 1, but different scales.
-  dist = tf.distributions.StudentT(df=2, loc=1, scale=[11, 22.])
+  dist = tfd.StudentT(df=2, loc=1, scale=[11, 22.])
 
   # Evaluate the pdf of both distributions on the same point, 3.0,
   # returning a length 2 tensor.
@@ -130,7 +131,7 @@
   df = tf.constant(2.0)
   loc = tf.constant(2.0)
   scale = tf.constant(11.0)
-  dist = tf.distributions.StudentT(df=df, loc=loc, scale=scale)
+  dist = tfd.StudentT(df=df, loc=loc, scale=scale)
   samples = dist.sample(5)  # Shape [5]
   loss = tf.reduce_mean(tf.square(samples))  # Arbitrary loss function
   # Unbiased stochastic gradients of the loss function
@@ -138,7 +139,6 @@
   ```
 
   """
-  # pylint: enable=line-too-long
 
   def __init__(self,
                df,
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 196161c..056015d 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -184,7 +184,7 @@
       between_op_list.append(op)
       # Clear the boolean so we won't add the inputs again.
       reached_ops.remove(op)
-      for inp in _Inputs(op, xs):
+      for inp in _NonEagerInputs(op, xs):
         queue.append(inp.op)
   # X in between_ops iff X is on a path of zero or more backpropagatable tensors
   # between from_ops and to_ops
@@ -196,7 +196,7 @@
   # Initialize pending count for between ops.
   pending_count = collections.defaultdict(int)
   for op in between_op_list:
-    for x in _Inputs(op, xs):
+    for x in _NonEagerInputs(op, xs):
       if x.op in between_ops:
         pending_count[x.op] += 1
 
@@ -347,7 +347,7 @@
   stop_ops = set()
   for op in from_ops:
     is_stop_op = True
-    for inp in _Inputs(op, xs):
+    for inp in _NonEagerInputs(op, xs):
       if pending_count[inp.op] > 0:
         is_stop_op = False
         break
@@ -371,10 +371,10 @@
   return op.type == "PartitionedCall" or op.type == "StatefulPartitionedCall"
 
 
-def _SymGrad(op, out_grads, xs):
+def _SymGrad(op, out_grads):
   """Backprop through a function call node op given its outputs' gradients."""
-  f_in = [x for x in _Inputs(op, xs)] + out_grads
-  f_types = [x.dtype for x in _Inputs(op, xs)]
+  f_in = [x for x in op.inputs] + out_grads
+  f_types = [x.dtype for x in op.inputs]
   f = attr_value_pb2.NameAttrList()
   if _IsPartitionedCall(op):
     f.name = op.get_attr("f").name
@@ -441,7 +441,7 @@
     if curr_op in from_ops:
       target_op = curr_op
       break
-    queue.extend(t.op for t in _Inputs(curr_op, xs))
+    queue.extend(t.op for t in _NonEagerInputs(curr_op, xs))
   assert target_op
   raise ValueError(
       "Cannot compute gradient inside while loop with respect to op '%s'. "
@@ -474,7 +474,8 @@
     A tensor, potentially from a different Graph/_function.FuncGraph.
   """
   # pylint: disable=protected-access
-  if _IsFunction(t.op.graph) and t.op.type == "Placeholder":
+  if (not isinstance(t, ops.EagerTensor) and
+      _IsFunction(t.op.graph) and t.op.type == "Placeholder"):
     for input_t, placeholder_t in _Captures(t.op.graph).items():
       if t == placeholder_t:
         return _MaybeCaptured(input_t)
@@ -484,9 +485,12 @@
 
 # TODO(skyewm): plumbing xs through everywhere is ugly, consider making
 # _GradientsHelper a class with xs as a member variable.
-def _Inputs(op, xs):
+def _NonEagerInputs(op, xs):
   """Returns the inputs of op, crossing closure boundaries where necessary.
 
+  Does not return any captured EagerTensors, i.e., the number of tensors
+  returned may be less than the actual number of inputs.
+
   Args:
     op: Operation
     xs: list of Tensors we are differentiating w.r.t.
@@ -497,12 +501,19 @@
     captured inputs.
   """
   if _IsFunction(op.graph):  # pylint: disable=protected-access
-    # If we're differentiating w.r.t. `t`, do not attempt to traverse through it
-    # to a captured value. The algorithm needs to "see" `t` in this case, even
-    # if it's a function input for a captured value, whereas usually we'd like
-    # to traverse through these closures as if the captured value was the direct
-    # input to op.
-    return [t if (t in xs) else _MaybeCaptured(t) for t in op.inputs]
+    inputs = []
+    for t in op.inputs:
+      # If we're differentiating w.r.t. `t`, do not attempt to traverse through
+      # it to a captured value. The algorithm needs to "see" `t` in this case,
+      # even if it's a function input for a captured value, whereas usually we'd
+      # like to traverse through these closures as if the captured value was the
+      # direct input to op.
+      if t not in xs:
+        t = _MaybeCaptured(t)
+        # Skip captured eager inputs.
+        if isinstance(t, ops.EagerTensor): continue
+      inputs.append(t)
+    return inputs
   else:
     return op.inputs
 
@@ -805,7 +816,7 @@
                 # For function call ops, we add a 'SymbolicGradient'
                 # node to the graph to compute gradients.
                 in_grads = _MaybeCompile(grad_scope, op, func_call,
-                                         lambda: _SymGrad(op, out_grads, xs))
+                                         lambda: _SymGrad(op, out_grads))
               in_grads = _AsList(in_grads)
               _VerifyGeneratedGradients(in_grads, op)
               if gate_gradients and len([x for x in in_grads
@@ -820,8 +831,9 @@
         else:
           # If no grad_fn is defined or none of out_grads is available,
           # just propagate a list of None backwards.
-          in_grads = [None] * len(_Inputs(op, xs))
-        for i, (t_in, in_grad) in enumerate(zip(_Inputs(op, xs), in_grads)):
+          in_grads = [None] * len(_NonEagerInputs(op, xs))
+        for i, (t_in, in_grad) in enumerate(zip(_NonEagerInputs(op, xs),
+                                                in_grads)):
           if in_grad is not None:
             if (isinstance(in_grad, ops.Tensor) and
                 t_in.dtype != dtypes.resource):
@@ -862,7 +874,7 @@
 def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
                                   xs):
   """Update pending count for the inputs of op and enqueue ready ops."""
-  for x in _Inputs(op, xs):
+  for x in _NonEagerInputs(op, xs):
     pending_count[x.op] -= 1
     ready = (pending_count[x.op] == 0)
     if loop_state and not ready:
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index 6243be6..4f6e5dc 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -531,6 +531,24 @@
       with self.cached_session() as sess:
         self.assertEqual(sess.run(z_grad), 3.0)
 
+  def testCapturedEagerTensors(self):
+    # Test that we can handle captured eager tensors unrelated to the gradient
+    # computation (i.e. we need to ignore them).
+    # TODO(skyewm): make it an error if you try to take the gradient wrt a
+    # captured EagerTensor
+    with context.eager_mode():
+      c = constant_op.constant(2.0, name="c")
+
+      @function.defun
+      def Foo():
+        x = constant_op.constant(10.0, name="x")
+        y = math_ops.multiply(x, c, name="y")
+        z = math_ops.multiply(y, 3.0, name="z")
+        g = gradients_impl.gradients(z, x)
+        return g[0]
+
+      self.assertEqual(Foo().numpy(), 6.0)
+
 
 class StopGradientTest(test_util.TensorFlowTestCase):
 
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index de260f3..325418d 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -29,7 +29,6 @@
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import functional_ops
 from tensorflow.python.ops import gen_image_ops
 from tensorflow.python.ops import gen_nn_ops
 from tensorflow.python.ops import math_ops
@@ -301,21 +300,21 @@
 
 def _random_flip(image, flip_index, seed, scope_name):
   """Randomly (50% chance) flip an image along axis `flip_index`.
-    Args:
-      image: 4-D Tensor of shape `[batch, height, width, channels]` or
-             3-D Tensor of shape `[height, width, channels]`.
-      flip_index: The dimension along which to flip the image.
-                  Vertical: 0, Horizontal: 1
-      seed: A Python integer. Used to create a random seed. See
-        `tf.set_random_seed`
-        for behavior.
-      scope_name: Name of the scope in which the ops are added.
 
-    Returns:
-      A tensor of the same type and shape as `image`.
+  Args:
+    image: 4-D Tensor of shape `[batch, height, width, channels]` or
+           3-D Tensor of shape `[height, width, channels]`.
+    flip_index: Dimension along which to flip image. Vertical: 0, Horizontal: 1
+    seed: A Python integer. Used to create a random seed. See
+      `tf.set_random_seed`
+      for behavior.
+    scope_name: Name of the scope in which the ops are added.
 
-    Raises:
-      ValueError: if the shape of `image` not supported.
+  Returns:
+    A tensor of the same type and shape as `image`.
+
+  Raises:
+    ValueError: if the shape of `image` is not supported.
   """
   with ops.name_scope(None, scope_name, [image]) as scope:
     image = ops.convert_to_tensor(image, name='image')
@@ -334,15 +333,16 @@
         result = result[0]  # TODO(b/111124878) remove this logic (CondV2).
       return fix_image_flip_shape(image, result)
     elif shape.ndims == 4:
+      batch_size = array_ops.shape(image)[0]
       uniform_random = random_ops.random_uniform(
-          [array_ops.shape(image)[0]], 0, 1.0, seed=seed
+          [batch_size], 0, 1.0, seed=seed
       )
-      mirror_cond = math_ops.less(uniform_random, .5)
-      return array_ops.where(
-          mirror_cond,
-          image,
-          functional_ops.map_fn(lambda x: array_ops.reverse(x, [flip_index]), image, dtype=image.dtype)
+      flips = math_ops.round(
+          array_ops.reshape(uniform_random, [batch_size, 1, 1, 1])
       )
+      flips = math_ops.cast(flips, image.dtype)
+      flipped_input = array_ops.reverse(image, [flip_index + 1])
+      return flips * flipped_input + (1 - flips) * image
     else:
       raise ValueError('\'image\' must have either 3 or 4 dimensions.')
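
A small numpy sketch (hypothetical data, not part of the change) of the arithmetic in the new batched branch above: a per-example 0/1 mask blends the reversed image with the original, replacing the earlier per-example `map_fn` plus `where`.

```python
import numpy as np

image = np.arange(6, dtype=np.float32).reshape([2, 1, 3, 1])           # [batch, h, w, c]
flips = np.array([1.0, 0.0], dtype=np.float32).reshape([2, 1, 1, 1])   # rounded uniforms
flipped = image[:, :, ::-1, :]   # reverse the width axis (flip_index + 1 == 2 for horizontal)
out = flips * flipped + (1 - flips) * image

print(out[0, 0, :, 0])  # [2. 1. 0.] -> this example was flipped
print(out[1, 0, :, 0])  # [3. 4. 5.] -> this example was left unchanged
```
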
 
diff --git a/tensorflow/python/ops/logging_ops.py b/tensorflow/python/ops/logging_ops.py
index df41933..4c53f33 100644
--- a/tensorflow/python/ops/logging_ops.py
+++ b/tensorflow/python/ops/logging_ops.py
@@ -19,13 +19,24 @@
 from __future__ import division
 from __future__ import print_function
 
+import pprint
+import random
+import sys
+
+import six
+
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import gen_logging_ops
+from tensorflow.python.ops import string_ops
 # go/tf-wildcard-import
 # pylint: disable=wildcard-import
 from tensorflow.python.ops.gen_logging_ops import *
 # pylint: enable=wildcard-import
+from tensorflow.python.platform import tf_logging
+from tensorflow.python.util import nest
 from tensorflow.python.util.deprecation import deprecated
 from tensorflow.python.util.tf_export import tf_export
 
@@ -40,7 +51,32 @@
 # For users with Python 3 or Python 2.7
 # with `from __future__ import print_function`, we could also allow lowercase.
 # See https://github.com/tensorflow/tensorflow/issues/18053
-@tf_export("Print")
+
+
+# pylint: disable=invalid-name
+@deprecated("2018-08-20", "Use tf.print instead of tf.Print. Note that "
+                          "tf.print returns a no-output operator that directly "
+                          "prints the output. Outside of defuns or eager mode, "
+                          "this operator will not be executed unless it is "
+                          "directly specified in session.run or used as a "
+                          "control dependency for other operators. This is "
+                          "only a concern in graph mode. Below is an example "
+                          "of how to ensure tf.print executes in graph mode:\n"
+                          """```python
+    sess = tf.Session()
+    with sess.as_default():
+        tensor = tf.range(10)
+        print_op = tf.print(tensor)
+        with tf.control_dependencies([print_op]):
+          out = tf.add(tensor, tensor)
+        sess.run(out)
+    ```
+Additionally, to use tf.print in python 2.7, users must make sure to import
+the following:
+
+  `from __future__ import print_function`
+""")
+@tf_export(v1=["Print"])
 def Print(input_, data, message=None, first_n=None, summarize=None,
           name=None):
   """Prints a list of tensors.
@@ -66,6 +102,228 @@
     A `Tensor`. Has the same type and contents as `input_`.
   """
   return gen_logging_ops._print(input_, data, message, first_n, summarize, name)
+# pylint: enable=invalid-name
+
+
+def _generate_placeholder_string(x, default_placeholder="{}"):
+  """Generate and return a string that does not appear in `x`."""
+  placeholder = default_placeholder
+  rng = random.Random(5)
+  while placeholder in x:
+    placeholder = placeholder + str(rng.randint(0, 9))
+  return placeholder
+
+
+# Temporarily disable pylint g-doc-args error to allow giving more context
+# about what the kwargs are.
+# Because we are using arbitrary-length positional arguments, python 2
+# does not support explicitly specifying the keyword arguments in the
+# function definition.
+# pylint: disable=g-doc-args
+@tf_export("print")
+def print_v2(*inputs, **kwargs):
+  """Print the specified inputs.
+
+  Returns an operator that prints the specified inputs to a desired
+  output stream or logging level. The inputs may be dense or sparse Tensors,
+  primitive python objects, data structures that contain Tensors, and printable
+  python objects. Printed tensors will recursively show the first and last
+  `summarize` elements of each dimension.
+
+  With eager execution enabled and/or inside a `tf.contrib.eager.defun`, this
+  operator will automatically execute, and users only need to call `tf.print`
+  without using the return value. When constructing graphs outside of a
+  `tf.contrib.eager.defun`, one must either include the returned op
+  in the input to `session.run`, or use the operator as a control dependency for
+  executed ops by specifying `with tf.control_dependencies([print_op])`.
+
+  @compatibility(python2)
+  In python 2.7, make sure to import the following:
+  `from __future__ import print_function`
+  @end_compatibility
+
+  Example:
+    Single-input usage:
+    ```python
+    tf.enable_eager_execution()
+    tensor = tf.range(10)
+    tf.print(tensor, output_stream=sys.stderr)
+    ```
+    (This prints "[0 1 2 ... 7 8 9]" to sys.stderr)
+
+    Multi-input usage:
+    ```python
+    tf.enable_eager_execution()
+    tensor = tf.range(10)
+    tf.print("tensors:", tensor, {2: tensor * 2}, output_stream=sys.stdout)
+    ```
+    (This prints "tensors: [0 1 2 ... 7 8 9] {2: [0 2 4 ... 14 16 18]}" to
+    sys.stdout)
+
+    Usage in a defun:
+    ```python
+    tf.enable_eager_execution()
+
+    @tf.contrib.eager.defun
+    def f():
+        tensor = tf.range(10)
+        tf.print(tensor, output_stream=sys.stderr)
+        return tensor
+
+    range_tensor = f()
+    ```
+    (This prints "[0 1 2 ... 7 8 9]" to sys.stderr)
+
+    Usage when constructing graphs:
+    ```python
+    sess = tf.Session()
+    with sess.as_default():
+        tensor = tf.range(10)
+        print_op = tf.print("tensors:", tensor, {2: tensor * 2},
+                            output_stream=sys.stdout)
+        with tf.control_dependencies([print_op]):
+          tripled_tensor = tensor * 3
+        sess.run(tripled_tensor)
+    ```
+    (This prints "tensors: [0 1 2 ... 7 8 9] {2: [0 2 4 ... 14 16 18]}" to
+    sys.stdout)
+
+  Note: This op is only partially compatible with Jupyter notebooks and colabs.
+    Because it prints to the C++ standard out / standard error, this will go
+    in the notebook kernel's console output, not in the notebook cell output.
+
+  Args:
+    *inputs: Positional arguments that are the inputs to print. Inputs in the
+      printed output will be separated by spaces. Inputs may be python
+      primitives, tensors, data structures such as dicts and lists that
+      may contain tensors (with the data structures possibly nested in
+      arbitrary ways), and printable python objects.
+    output_stream: The output stream or logging level to print to. Defaults to
+      sys.stderr, but sys.stdout, tf.logging.info, tf.logging.warning, and
+      tf.logging.error are also supported.
+    summarize: The first and last `summarize` elements within each dimension are
+      recursively printed per Tensor. If None, then the first 3 and last 3
+      elements of each dimension are printed for each tensor. If set to -1, it
+      will print all elements of every tensor.
+    name: A name for the operation (optional).
+
+  Returns:
+    A print operator that prints the specified inputs in the specified output
+    stream or logging level.
+
+  Raises:
+    ValueError: If an unsupported output stream is specified.
+  """
+  # Because we are using arbitrary-length positional arguments, python 2
+  # does not support explicitly specifying the keyword arguments in the
+  # function definition. So, we manually get the keyword arguments w/ default
+  # values here.
+  output_stream = kwargs.pop("output_stream", sys.stderr)
+  name = kwargs.pop("name", None)
+  summarize = kwargs.pop("summarize", 3)
+  if kwargs:
+    raise ValueError("Unrecognized keyword arguments for tf.print: %s" % kwargs)
+  format_name = None
+  if name:
+    format_name = name + "_format"
+
+  # Match the C++ string constants representing the different output streams.
+  # Keep this updated!
+  output_stream_to_constant = {
+      sys.stdout: "stdout",
+      sys.stderr: "stderr",
+      tf_logging.INFO: "log(info)",
+      tf_logging.info: "log(info)",
+      tf_logging.WARN: "log(warning)",
+      tf_logging.warning: "log(warning)",
+      tf_logging.warn: "log(warning)",
+      tf_logging.ERROR: "log(error)",
+      tf_logging.error: "log(error)",
+  }
+
+  output_stream_string = output_stream_to_constant.get(output_stream)
+  if not output_stream_string:
+    raise ValueError(
+        "Unsupported output stream or logging level " +
+        str(output_stream) + ". Supported streams are sys.stdout, "
+                             "sys.stderr, tf.logging.info, "
+                             "tf.logging.warning, tf.logging.error")
+
+  # If we are only printing a single string scalar, there is no need to format
+  if (len(inputs) == 1 and tensor_util.is_tensor(inputs[0])
+      and (not isinstance(inputs[0], sparse_tensor.SparseTensor))
+      and inputs[0].shape and (inputs[0].dtype == dtypes.string)):
+    formatted_string = inputs[0]
+  # Otherwise, we construct an appropriate template for the tensors we are
+  # printing, and format the template using those tensors.
+  else:
+    # For each input to this print function, we extract any nested tensors,
+    # and construct an appropriate template representing the printed input.
+    templates = []
+    tensors = []
+    tensor_free_structure = nest.map_structure(
+        lambda x: "" if tensor_util.is_tensor(x) else x,
+        inputs)
+    tensor_free_template = " ".join(pprint.pformat(x)
+                                    for x in tensor_free_structure)
+    placeholder = _generate_placeholder_string(tensor_free_template)
+
+    for input_ in inputs:
+      placeholders = []
+      # Use the nest utilities to flatten & process any nested elements in this
+      # input. The placeholder for a tensor in the template should be the
+      # placeholder string, and the placeholder for a non-tensor can just be
+      # the printed value of the non-tensor itself.
+      for x in nest.flatten(input_):
+        # support sparse tensors
+        if isinstance(x, sparse_tensor.SparseTensor):
+          tensors.extend([x.indices, x.values, x.dense_shape])
+          placeholders.append(
+              "SparseTensor(indices={}, values={}, shape={})".format(
+                  placeholder, placeholder, placeholder)
+          )
+        elif tensor_util.is_tensor(x):
+          tensors.append(x)
+          placeholders.append(placeholder)
+        else:
+          placeholders.append(x)
+
+      if isinstance(input_, six.string_types):
+        # If the current input to format/print is a normal string, that string
+        # can act as the template.
+        cur_template = input_
+      else:
+        # We pack the placeholders into a data structure that matches the
+        # input data structure format, then format that data structure
+        # into a string template.
+        #
+        # NOTE: We must use pprint.pformat here for building the template for
+        # unordered data structures such as `dict`, because `str` doesn't
+        # guarantee orderings, while pprint prints in sorted order. pprint
+        # will match the ordering of `nest.flatten`.
+        # This even works when nest.flatten reorders OrderedDicts, because
+        # pprint is printing *after* the OrderedDicts have been reordered.
+        cur_template = pprint.pformat(
+            nest.pack_sequence_as(input_, placeholders))
+      templates.append(cur_template)
+
+    # We join the templates for the various inputs into a single larger
+    # template. We also remove all quotes surrounding the placeholders, so that
+    # the formatted/printed output will not contain quotes around tensors.
+    # (example of where these quotes might appear: if we have added a
+    # placeholder string into a list, then pretty-formatted that list)
+    template = " ".join(templates)
+    template = template.replace("'" + placeholder + "'", placeholder)
+    formatted_string = string_ops.string_format(
+        inputs=tensors, template=template, placeholder=placeholder,
+        summarize=summarize,
+        name=format_name)
+
+  return gen_logging_ops.print_v2(formatted_string,
+                                  output_stream=output_stream_string,
+                                  name=name)
+# pylint: enable=g-doc-args
 
 
 @ops.RegisterGradient("Print")
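
A hedged sketch of the collision-avoidance helper added above (`_generate_placeholder_string` is private, so the import path is shown only for illustration): when the default `{}` placeholder already occurs in the tensor-free template, seeded random digits are appended until a unique placeholder is found.

```python
from tensorflow.python.ops.logging_ops import _generate_placeholder_string

print(_generate_placeholder_string("no braces here"))    # "{}"
print(_generate_placeholder_string("already uses {}"))   # "{}" followed by one or more digits
```
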
diff --git a/tensorflow/python/ops/losses/util_test.py b/tensorflow/python/ops/losses/util_test.py
index 7fa7a41..df2e60e 100644
--- a/tensorflow/python/ops/losses/util_test.py
+++ b/tensorflow/python/ops/losses/util_test.py
@@ -28,7 +28,7 @@
 
   def testGetRegularizationLoss(self):
     # Empty regularization collection should evaluate to 0.0.
-    with self.test_session():
+    with self.cached_session():
       self.assertEqual(0.0, util.get_regularization_loss().eval())
 
     # Loss should sum.
@@ -36,14 +36,14 @@
         ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(2.0))
     ops.add_to_collection(
         ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(3.0))
-    with self.test_session():
+    with self.cached_session():
       self.assertEqual(5.0, util.get_regularization_loss().eval())
 
     # Check scope capture mechanism.
     with ops.name_scope('scope1'):
       ops.add_to_collection(
           ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(-1.0))
-    with self.test_session():
+    with self.cached_session():
       self.assertEqual(-1.0, util.get_regularization_loss('scope1').eval())
 
 
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 7c59232..f57abf6 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -2898,29 +2898,29 @@
         shape_a = a.get_shape().as_list()
         axes = [i if i >= 0 else i + len(shape_a) for i in axes]
         free = [i for i in xrange(len(shape_a)) if i not in axes]
-        free_dims_static = [shape_a[i] for i in free]
+        axes_dims = [shape_a[i] for i in axes]
+        free_dims = [shape_a[i] for i in free]
+        free_dims_static = free_dims
+        axes = ops.convert_to_tensor(axes, dtype=dtypes.int32, name="axes")
+        free = ops.convert_to_tensor(free, dtype=dtypes.int32, name="free")
+        shape_a = array_ops.shape(a)
       else:
         free_dims_static = None
-      shape_a = array_ops.shape(a)
-      rank_a = array_ops.rank(a)
-      # TODO(b/115583659): Automate this.
-      with ops.device("/cpu:0"):
+        shape_a = array_ops.shape(a)
+        rank_a = array_ops.rank(a)
         axes = ops.convert_to_tensor(axes, dtype=dtypes.int32, name="axes")
-        axes = cast(axes >= 0, dtypes.int32) * axes + cast(
-            axes < 0, dtypes.int32) * (
-                axes + rank_a)
+        axes = array_ops.where(axes >= 0, axes, axes + rank_a)
         free, _ = array_ops.setdiff1d(range(rank_a), axes)
-        free_dims = array_ops.gather(shape_a, free)
-        axes_dims = array_ops.gather(shape_a, axes)
-        prod_free_dims = reduce_prod(free_dims)
-        prod_axes_dims = reduce_prod(axes_dims)
-        perm = array_ops.concat([axes_dims, free_dims], 0)
-        if flipped:
-          perm = array_ops.concat([axes, free], 0)
-          new_shape = array_ops.stack([prod_axes_dims, prod_free_dims])
-        else:
-          perm = array_ops.concat([free, axes], 0)
-          new_shape = array_ops.stack([prod_free_dims, prod_axes_dims])
+      free_dims = array_ops.gather(shape_a, free)
+      axes_dims = array_ops.gather(shape_a, axes)
+      prod_free_dims = reduce_prod(free_dims)
+      prod_axes_dims = reduce_prod(axes_dims)
+      if flipped:
+        perm = array_ops.concat([axes, free], 0)
+        new_shape = array_ops.stack([prod_axes_dims, prod_free_dims])
+      else:
+        perm = array_ops.concat([free, axes], 0)
+        new_shape = array_ops.stack([prod_free_dims, prod_axes_dims])
       reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape)
       return reshaped_a, free_dims, free_dims_static
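+
+A quick usage sketch of the tensordot reshaping path touched above (a hedged illustration, not part of the change): when both operands have fully static shapes, the free dimensions are now computed in Python, so the result keeps a defined static shape.
+
+```python
+import tensorflow as tf
+
+a = tf.zeros([2, 3, 4])
+b = tf.zeros([4, 5])
+c = tf.tensordot(a, b, axes=[[2], [0]])  # contract last axis of a with first axis of b
+print(c.shape)  # (2, 3, 5) -- known statically
+```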
 
diff --git a/tensorflow/python/ops/parallel_for/BUILD b/tensorflow/python/ops/parallel_for/BUILD
index 015181a..07fc943 100644
--- a/tensorflow/python/ops/parallel_for/BUILD
+++ b/tensorflow/python/ops/parallel_for/BUILD
@@ -123,6 +123,8 @@
         "//third_party/py/numpy",
         "//tensorflow/python:layers",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python:control_flow_ops",
+        "//tensorflow/python:functional_ops",
         "//tensorflow/python:random_ops",
         "//tensorflow/python/ops/losses",
     ],
diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
index d403b0c..6e276de 100644
--- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
+++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
@@ -31,6 +31,8 @@
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import bitwise_ops
+from tensorflow.python.ops import clip_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import gradients as gradient_ops
@@ -300,28 +302,129 @@
     self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 2)
 
 
-class MathTest(PForTest):
+class BitwiseTest(PForTest):
 
-  def test_unary_cwise_ops(self):
-    for op in [
-        math_ops.tanh, nn.relu, math_ops.sigmoid, math_ops.negative,
-        math_ops.square
-    ]:
-      x = random_ops.random_uniform([3, 5])
+  def test_unary_cwise(self):
+    for op in [bitwise_ops.invert]:
+      x = random_ops.random_uniform([7, 3, 5], maxval=10, dtype=dtypes.int32)
 
       # pylint: disable=cell-var-from-loop
       def loop_fn(i):
         x1 = array_ops.gather(x, i)
-        y = op(x1)
-        loss = math_ops.reduce_sum(y * y)
-        return op(x), y, gradient_ops.gradients(loss, x1)
+        return op(x1)
+      # pylint: enable=cell-var-from-loop
+
+      self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.int32])
+
+  def test_binary_cwise(self):
+    binary_ops = [
+        bitwise_ops.bitwise_and,
+        bitwise_ops.bitwise_or,
+        bitwise_ops.bitwise_xor,
+        bitwise_ops.left_shift,
+        bitwise_ops.right_shift,
+    ]
+    for op in binary_ops:
+      x = random_ops.random_uniform([7, 3, 5], maxval=10, dtype=dtypes.int32)
+      y = random_ops.random_uniform([3, 5], maxval=10, dtype=dtypes.int32)
+
+      output_dtypes = []
+      # pylint: disable=cell-var-from-loop
+      def loop_fn(i):
+        x1 = array_ops.gather(x, i)
+        y1 = array_ops.gather(y, i)
+        outputs = [op(x, y), op(x1, y), op(x, y1), op(x1, y1), op(x1, x1)]
+        del output_dtypes[:]
+        output_dtypes.extend([t.dtype for t in outputs])
+        return outputs
+      # pylint: enable=cell-var-from-loop
+      self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=output_dtypes)
+
+
+class MathTest(PForTest):
+
+  def test_unary_cwise_ops(self):
+    complex_ops = [
+        math_ops.angle,
+        math_ops.imag,
+        math_ops.complex_abs,
+        math_ops.real,
+        math_ops.conj,
+    ]
+    real_ops = [
+        lambda x: math_ops.acosh(1 + math_ops.square(x)),
+        math_ops.abs,
+        math_ops.acos,
+        math_ops.asin,
+        math_ops.asinh,
+        math_ops.atan,
+        math_ops.atanh,
+        math_ops.bessel_i0e,
+        math_ops.bessel_i1e,
+        math_ops.cos,
+        math_ops.cosh,
+        math_ops.digamma,
+        math_ops.erf,
+        math_ops.erfc,
+        math_ops.exp,
+        math_ops.expm1,
+        math_ops.inv,
+        math_ops.is_finite,
+        math_ops.is_inf,
+        math_ops.lgamma,
+        math_ops.log,
+        math_ops.log1p,
+        math_ops.neg,
+        math_ops.negative,
+        math_ops.reciprocal,
+        math_ops.rint,
+        math_ops.round,
+        math_ops.rsqrt,
+        math_ops.sigmoid,
+        math_ops.sign,
+        math_ops.sin,
+        math_ops.sinh,
+        math_ops.sqrt,
+        math_ops.square,
+        math_ops.tan,
+        math_ops.tanh,
+        nn.elu,
+        nn.relu,
+        nn.relu6,
+        nn.selu,
+        nn.softplus,
+        nn.softsign,
+    ]
+    for op in complex_ops + real_ops:
+      x = random_ops.random_uniform([3, 5])
+      if op in complex_ops:
+        y = random_ops.random_uniform([3, 5])
+        x = math_ops.complex(x, y)
+
+      # pylint: disable=cell-var-from-loop
+      output_dtypes = []
+      def loop_fn(i):
+        x1 = array_ops.gather(x, i)
+        y1 = op(x1)
+        outputs = [op(x), y1]
+        if y1.dtype == dtypes.float32:
+          loss = math_ops.reduce_sum(y1 * y1)
+          grad = gradient_ops.gradients(loss, x1)
+          if grad and grad[0] is not None:
+            outputs.extend(grad)
+        del output_dtypes[:]
+        output_dtypes.extend([t.dtype for t in outputs])
+        return outputs
 
       # pylint: enable=cell-var-from-loop
 
-      self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 3)
+      self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=output_dtypes)
 
   def test_unary_cwise_no_grad(self):
-    for op in [math_ops.ceil, math_ops.floor, math_ops.logical_not]:
+    for op in [math_ops.ceil,
+               math_ops.floor,
+               math_ops.logical_not]:
       x = random_ops.random_uniform([3, 5])
       if op == math_ops.logical_not:
         x = x > 0
@@ -336,33 +439,80 @@
 
   def test_binary_cwise_ops(self):
     logical_ops = [
-        math_ops.logical_and, math_ops.logical_or, math_ops.logical_xor
+        math_ops.logical_and,
+        math_ops.logical_or,
+        math_ops.logical_xor
     ]
-    bool_ops = [
-        math_ops.less, math_ops.less_equal, math_ops.greater,
-        math_ops.greater_equal, math_ops.equal, math_ops.not_equal
-    ]
+
+    # Wrapper functions restricting the range of inputs of zeta and polygamma.
+    def safe_polygamma(x, y):
+      return math_ops.polygamma(
+          math_ops.round(clip_ops.clip_by_value(y, 1, 10)),
+          x * x + 1)
+
+    def safe_zeta(x, y):
+      return math_ops.zeta(x * x + 1, y * y)
+
     float_ops = [
-        math_ops.add, math_ops.subtract, math_ops.multiply, math_ops.divide,
-        math_ops.maximum, math_ops.minimum
+        math_ops.add,
+        math_ops.add_v2,
+        math_ops.atan2,
+        math_ops.complex,
+        math_ops.div,
+        math_ops.divide,
+        math_ops.div_no_nan,
+        math_ops.equal,
+        math_ops.floor_div,
+        math_ops.floor_mod,
+        math_ops.greater,
+        math_ops.greater_equal,
+        math_ops.igamma,
+        math_ops.igammac,
+        math_ops.igamma_grad_a,
+        math_ops.less,
+        math_ops.less_equal,
+        math_ops.maximum,
+        math_ops.minimum,
+        math_ops.mod,
+        math_ops.multiply,
+        math_ops.not_equal,
+        math_ops.pow,
+        math_ops.squared_difference,
+        math_ops.subtract,
+        math_ops.truncate_mod,
+        safe_polygamma,
+        safe_zeta,
     ]
-    for op in logical_ops + bool_ops + float_ops:
+    for op in logical_ops + float_ops:
       x = random_ops.random_uniform([7, 3, 5])
       y = random_ops.random_uniform([3, 5])
       if op in logical_ops:
         x = x > 0
         y = y > 0
 
+      output_dtypes = []
       # pylint: disable=cell-var-from-loop
       def loop_fn(i):
         x1 = array_ops.gather(x, i)
         y1 = array_ops.gather(y, i)
-        return op(x, y), op(x1, y), op(x, y1), op(x1, y1), op(x1, x1)
-
+        outputs = [op(x, y), op(x1, y), op(x, y1), op(x1, y1), op(x1, x1)]
+        del output_dtypes[:]
+        output_dtypes.extend([t.dtype for t in outputs])
+        return outputs
       # pylint: enable=cell-var-from-loop
 
-      dtype = dtypes.float32 if op in float_ops else dtypes.bool
-      self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtype] * 5)
+      self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=output_dtypes)
+
+  def test_approximate_equal(self):
+    x = random_ops.random_uniform([3, 5])
+    y = random_ops.random_uniform([3, 5])
+
+    def loop_fn(i):
+      x1 = array_ops.gather(x, i)
+      y1 = array_ops.gather(y, i)
+      return math_ops.approximate_equal(x1, y1)
+
+    self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.bool])
 
   def test_addn(self):
     x = random_ops.random_uniform([2, 3, 5])
diff --git a/tensorflow/python/ops/parallel_for/gradients.py b/tensorflow/python/ops/parallel_for/gradients.py
index 460de0a..1f026b3 100644
--- a/tensorflow/python/ops/parallel_for/gradients.py
+++ b/tensorflow/python/ops/parallel_for/gradients.py
@@ -42,6 +42,7 @@
     [y_1, ..., y_n, x_1, ..., x_m].
   """
   flat_inputs = nest.flatten(inputs)
+  output_tensor_shape = output.shape
   output_shape = array_ops.shape(output)
   output = array_ops.reshape(output, [-1])
 
@@ -65,6 +66,7 @@
       new_shape = array_ops.concat(
           [output_shape, array_ops.shape(out)[1:]], axis=0)
       out = array_ops.reshape(out, new_shape)
+      out.set_shape(output_tensor_shape.concatenate(flat_inputs[i].shape))
     pfor_outputs[i] = out
 
   return nest.pack_sequence_as(inputs, pfor_outputs)
diff --git a/tensorflow/python/ops/parallel_for/gradients_test.py b/tensorflow/python/ops/parallel_for/gradients_test.py
index 628c676..5467f55 100644
--- a/tensorflow/python/ops/parallel_for/gradients_test.py
+++ b/tensorflow/python/ops/parallel_for/gradients_test.py
@@ -32,6 +32,8 @@
 from tensorflow.python.keras.engine import training as keras_training
 from tensorflow.python.layers import layers as tf_layers
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops as tf_control_flow_ops
+from tensorflow.python.ops import functional_ops
 from tensorflow.python.ops import gradients as gradient_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
@@ -355,6 +357,30 @@
     self.run_and_assert_equal(answer, jacobian_pfor)
     self.run_and_assert_equal(answer, jacobian_while)
 
+  def test_jacobian_scan_shape(self):
+    # Shape x: [3, 4]
+    x = random_ops.random_uniform([3, 4])
+    elems = random_ops.random_uniform([6])
+    # Shape y: [6, 3, 4]
+    y = functional_ops.scan(lambda a, e: a + e, elems, initializer=x)
+    jacobian = gradients.jacobian(y, x)
+
+    expected_shape = [6, 3, 4, 3, 4]
+    self.assertAllEqual(expected_shape, jacobian.shape.as_list())
+
+  def test_jacobian_while_loop_shape(self):
+    # Shape x: [3, 4]
+    x = random_ops.random_uniform([3, 4])
+    _, y = tf_control_flow_ops.while_loop(lambda i, a: i > 5.,
+                                          lambda i, a: (i + 1, a + i),
+                                          (constant_op.constant(0.), x))
+    # Shape y: [2, 3]
+    y = y[:2, :3]
+    jacobian = gradients.jacobian(y, x)
+
+    expected_shape = [2, 3, 3, 4]
+    self.assertAllEqual(expected_shape, jacobian.shape.as_list())
+
   def test_jacobian_unknown_shape(self):
     with self.cached_session() as sess:
       x = array_ops.placeholder(dtypes.float32, shape=[None, None])
diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py
index f9153b6..e0f6d51 100644
--- a/tensorflow/python/ops/parallel_for/pfor.py
+++ b/tensorflow/python/ops/parallel_for/pfor.py
@@ -28,6 +28,7 @@
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import bitwise_ops
 from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import data_flow_ops
@@ -1922,37 +1923,114 @@
   return wrap(math_ops.cast(inp, dtype), True)
 
 
-# Note that ops handled here do not have attributes except "T", and hence don't
-# need extra arguments passed to the cwise_op call below.
+@RegisterPForWithArgs("Abs", math_ops.abs)
+@RegisterPForWithArgs("Acosh", math_ops.acosh)
+@RegisterPForWithArgs("Acos", math_ops.acos)
 @RegisterPForWithArgs("Add", math_ops.add)
+@RegisterPForWithArgs("AddV2", math_ops.add_v2)
+@RegisterPForWithArgs("Angle", math_ops.angle)
+@RegisterPForWithArgs("Asinh", math_ops.asinh)
+@RegisterPForWithArgs("Asin", math_ops.asin)
+@RegisterPForWithArgs("Atan2", math_ops.atan2)
+@RegisterPForWithArgs("Atanh", math_ops.atanh)
+@RegisterPForWithArgs("Atan", math_ops.atan)
+@RegisterPForWithArgs("BesselI0e", math_ops.bessel_i0e)
+@RegisterPForWithArgs("BesselI1e", math_ops.bessel_i1e)
+@RegisterPForWithArgs("BitwiseAnd", bitwise_ops.bitwise_and)
+@RegisterPForWithArgs("BitwiseOr", bitwise_ops.bitwise_or)
+@RegisterPForWithArgs("BitwiseXor", bitwise_ops.bitwise_xor)
 @RegisterPForWithArgs("Ceil", math_ops.ceil)
+@RegisterPForWithArgs("ComplexAbs", math_ops.complex_abs)
+@RegisterPForWithArgs("Complex", math_ops.complex)
+@RegisterPForWithArgs("Conj", math_ops.conj)
+@RegisterPForWithArgs("Cosh", math_ops.cosh)
+@RegisterPForWithArgs("Cos", math_ops.cos)
+@RegisterPForWithArgs("Digamma", math_ops.digamma)
+@RegisterPForWithArgs("Div", math_ops.div)
+@RegisterPForWithArgs("DivNoNan", math_ops.div_no_nan)
+@RegisterPForWithArgs("Elu", nn_ops.elu)
 @RegisterPForWithArgs("Equal", math_ops.equal)
-@RegisterPForWithArgs("NotEqual", math_ops.not_equal)
+@RegisterPForWithArgs("Erfc", math_ops.erfc)
+@RegisterPForWithArgs("Erf", math_ops.erf)
+@RegisterPForWithArgs("Expm1", math_ops.expm1)
+@RegisterPForWithArgs("Exp", math_ops.exp)
+@RegisterPForWithArgs("FloorDiv", math_ops.floor_div)
 @RegisterPForWithArgs("Floor", math_ops.floor)
-@RegisterPForWithArgs("Greater", math_ops.greater)
+@RegisterPForWithArgs("FloorMod", math_ops.floor_mod)
 @RegisterPForWithArgs("GreaterEqual", math_ops.greater_equal)
-@RegisterPForWithArgs("Less", math_ops.less)
+@RegisterPForWithArgs("Greater", math_ops.greater)
+@RegisterPForWithArgs("Igammac", math_ops.igammac)
+@RegisterPForWithArgs("IgammaGradA", math_ops.igamma_grad_a)
+@RegisterPForWithArgs("Igamma", math_ops.igamma)
+@RegisterPForWithArgs("Imag", math_ops.imag)
+@RegisterPForWithArgs("Invert", bitwise_ops.invert)
+@RegisterPForWithArgs("Inv", math_ops.inv)
+@RegisterPForWithArgs("IsFinite", math_ops.is_finite)
+@RegisterPForWithArgs("IsInf", math_ops.is_inf)
+@RegisterPForWithArgs("LeftShift", bitwise_ops.left_shift)
 @RegisterPForWithArgs("LessEqual", math_ops.less_equal)
-@RegisterPForWithArgs("LogicalOr", math_ops.logical_or)
+@RegisterPForWithArgs("Less", math_ops.less)
+@RegisterPForWithArgs("Lgamma", math_ops.lgamma)
+@RegisterPForWithArgs("Log1p", math_ops.log1p)
 @RegisterPForWithArgs("LogicalAnd", math_ops.logical_and)
 @RegisterPForWithArgs("LogicalNot", math_ops.logical_not)
+@RegisterPForWithArgs("LogicalOr", math_ops.logical_or)
 @RegisterPForWithArgs("LogicalXor", math_ops.logical_xor)
+@RegisterPForWithArgs("Log", math_ops.log)
 @RegisterPForWithArgs("Maximum", math_ops.maximum)
 @RegisterPForWithArgs("Minimum", math_ops.minimum)
+@RegisterPForWithArgs("Mod", math_ops.mod)
 @RegisterPForWithArgs("Mul", math_ops.multiply)
 @RegisterPForWithArgs("Neg", math_ops.negative)
+@RegisterPForWithArgs("NotEqual", math_ops.not_equal)
+@RegisterPForWithArgs("Polygamma", math_ops.polygamma)
+@RegisterPForWithArgs("Pow", math_ops.pow)
 @RegisterPForWithArgs("RealDiv", math_ops.divide)
+@RegisterPForWithArgs("Real", math_ops.real)
+@RegisterPForWithArgs("ReciprocalGrad", math_ops.reciprocal_grad)
+@RegisterPForWithArgs("Reciprocal", math_ops.reciprocal)
+@RegisterPForWithArgs("Relu6", nn_ops.relu6)
 @RegisterPForWithArgs("Relu", nn_ops.relu)
+@RegisterPForWithArgs("RightShift", bitwise_ops.right_shift)
+@RegisterPForWithArgs("Rint", math_ops.rint)
+@RegisterPForWithArgs("Round", math_ops.round)
+@RegisterPForWithArgs("RsqrtGrad", math_ops.rsqrt_grad)
+@RegisterPForWithArgs("Rsqrt", math_ops.rsqrt)
+@RegisterPForWithArgs("Selu", nn_ops.selu)
 @RegisterPForWithArgs("Sigmoid", math_ops.sigmoid)
+@RegisterPForWithArgs("Sign", math_ops.sign)
+@RegisterPForWithArgs("Sinh", math_ops.sinh)
+@RegisterPForWithArgs("Sin", math_ops.sin)
+@RegisterPForWithArgs("Softplus", nn_ops.softplus)
+@RegisterPForWithArgs("Softsign", nn_ops.softsign)
+@RegisterPForWithArgs("SqrtGrad", math_ops.sqrt_grad)
+@RegisterPForWithArgs("Sqrt", math_ops.sqrt)
+@RegisterPForWithArgs("SquaredDifference", math_ops.squared_difference)
 @RegisterPForWithArgs("Square", math_ops.square)
 @RegisterPForWithArgs("Sub", math_ops.subtract)
 @RegisterPForWithArgs("Tanh", math_ops.tanh)
+@RegisterPForWithArgs("Tan", math_ops.tan)
+@RegisterPForWithArgs("TruncateDiv", math_ops.truncate_div)
+@RegisterPForWithArgs("TruncateMod", math_ops.truncate_mod)
+@RegisterPForWithArgs("Zeta", math_ops.zeta)
 def _convert_cwise(pfor_input, op_type, op_func):
-  del op_type
+  # Note that ops handled here do not have attributes except "T" and "Tout", and
+  # hence don't need extra arguments passed to the cwise_op call below.
+  for attr in pfor_input.op.node_def.attr.keys():
+    assert attr in [u"T", u"Tout"], (op_type, attr)
   pfor_input.expanddim_inputs_for_broadcast()
   return wrap(op_func(*[x.t for x in pfor_input.inputs]), True)
 
 
+@RegisterPFor("ApproximateEqual")
+def _convert_approximate_equal(pfor_input):
+  pfor_input.expanddim_inputs_for_broadcast()
+  x = pfor_input.input(0)[0]
+  y = pfor_input.input(1)[0]
+  tolerance = pfor_input.get_attr("tolerance")
+  return wrap(math_ops.approximate_equal(x, y, tolerance=tolerance), True)
+
+
 @RegisterPFor("Shape")
 def _convert_shape(pfor_input):
   out_type = pfor_input.get_attr("out_type")
@@ -2009,10 +2087,14 @@
 
 # Some required ops are not exposed under the tf namespace. Hence relying on
 # _create_op to create them.
+@RegisterPForWithArgs("EluGrad")
+@RegisterPForWithArgs("Relu6Grad")
 @RegisterPForWithArgs("ReluGrad")
-@RegisterPForWithArgs("TanhGrad")
+@RegisterPForWithArgs("SeluGrad")
 @RegisterPForWithArgs("SigmoidGrad")
 @RegisterPForWithArgs("SoftplusGrad")
+@RegisterPForWithArgs("SoftsignGrad")
+@RegisterPForWithArgs("TanhGrad")
 def _convert_grads(pfor_input, op_type, *args, **kw_args):
   del args
   del kw_args
diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py
index bb8da31..b3e03a0 100644
--- a/tensorflow/python/ops/parsing_ops.py
+++ b/tensorflow/python/ops/parsing_ops.py
@@ -981,9 +981,10 @@
     name: A name for this operation (optional).
 
   Returns:
-    A tuple of two `dict`s, each mapping keys to `Tensor`s and `SparseTensor`s.
-    The first dict contains the context key/values.
-    The second dict contains the feature_list key/values.
+    A tuple of three `dict`s, each mapping keys to `Tensor`s and
+    `SparseTensor`s. The first dict contains the context key/values,
+    the second dict contains the feature_list key/values, and the final dict
+    contains the lengths of any dense feature_list features.
 
   Raises:
     ValueError: if any feature is invalid.
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 55c2eb5..4a126e9 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -48,14 +48,14 @@
   assert ops._USE_C_SHAPES  # pylint: disable=protected-access
   assert type(graph_op) == ops.Tensor  # pylint: disable=unidiomatic-typecheck
 
-  handle_data = pywrap_tensorflow.GetResourceHandleShapeAndType(
+  handle_data = pywrap_tensorflow.GetHandleShapeAndType(
       graph_op.graph._c_graph, graph_op._as_tf_output())  # pylint: disable=protected-access
 
   return cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString(
       compat.as_bytes(handle_data))
 
 
-def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
+def eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
   """Creates a variable handle with information to do shape inference."""
   container = ops.get_default_graph()._container  # pylint: disable=protected-access
   if container is None:
@@ -397,61 +397,33 @@
           # When in eager mode use a uid for the shared_name, to prevent
           # accidental sharing.
           shared_name = "%s_%d" % (handle_name, ops.uid())
-        if init_from_fn:
-          # Use attr_scope and device(None) to simulate the behavior of
-          # colocate_with when the variable we want to colocate with doesn't
-          # yet exist.
-          if self._in_graph_mode:
-            attr = attr_value_pb2.AttrValue(
-                list=attr_value_pb2.AttrValue.ListValue(
-                    s=[compat.as_bytes("loc:@%s" % handle_name)]))
-            with ops.get_default_graph()._attr_scope({"_class": attr}):
-              with ops.name_scope("Initializer"), ops.device(None):
-                initial_value = ops.convert_to_tensor(
-                    initial_value(), name="initial_value", dtype=dtype)
-              self._handle = _eager_safe_variable_handle(
-                  shape=initial_value.get_shape(),
-                  dtype=initial_value.dtype.base_dtype,
-                  shared_name=shared_name,
-                  name=name,
-                  graph_mode=self._in_graph_mode)
-              self._shape = initial_value.get_shape()
-          else:
-            initial_value = initial_value()
-            with ops.name_scope("Initializer"):
-              initial_value = ops.convert_to_tensor(
-                  initial_value, name="initial_value", dtype=dtype)
-            self._handle = _eager_safe_variable_handle(
-                shape=initial_value.get_shape(),
-                dtype=initial_value.dtype.base_dtype,
-                shared_name=shared_name,
-                name=name,
-                graph_mode=False)
-            self._shape = initial_value.get_shape()
-        # pylint: enable=protected-access
-
-        # Or get the initial value from a Tensor or Python object.
-        else:
-          with ops.name_scope("Initializer"):
+        # Use attr_scope and device(None) to simulate the behavior of
+        # colocate_with when the variable we want to colocate with doesn't
+        # yet exist.
+        attr = attr_value_pb2.AttrValue(
+            list=attr_value_pb2.AttrValue.ListValue(
+                s=[compat.as_bytes("loc:@%s" % handle_name)]))
+        with ops.get_default_graph()._attr_scope({"_class": attr}):
+          with ops.name_scope("Initializer"), ops.device(None):
             initial_value = ops.convert_to_tensor(
-                initial_value, name="initial_value", dtype=dtype)
-          # pylint: disable=protected-access
-          if (self._in_graph_mode and initial_value is not None and
-              initial_value.op._get_control_flow_context() is not None):
-            raise ValueError(
-                "Initializer for variable %s is from inside a control-flow "
-                "construct, such as a loop or conditional. When creating a "
-                "variable inside a loop or conditional, use a lambda as the "
-                "initializer." % name)
-          # pylint: enable=protected-access
-          self._handle = _eager_safe_variable_handle(
+                initial_value() if init_from_fn else initial_value,
+                name="initial_value", dtype=dtype)
+          self._handle = eager_safe_variable_handle(
               shape=initial_value.get_shape(),
               dtype=initial_value.dtype.base_dtype,
               shared_name=shared_name,
               name=name,
               graph_mode=self._in_graph_mode)
-          self._shape = initial_value.get_shape()
-
+        self._shape = initial_value.shape
+        # pylint: disable=protected-access
+        if (self._in_graph_mode and initial_value is not None and
+            initial_value.op._get_control_flow_context() is not None):
+          raise ValueError(
+              "Initializer for variable %s is from inside a control-flow "
+              "construct, such as a loop or conditional. When creating a "
+              "variable inside a loop or conditional, use a lambda as the "
+              "initializer." % name)
+        # pylint: enable=protected-access
         self._unique_id = shared_name
         self._initial_value = initial_value if self._in_graph_mode else None
         self._handle_name = handle_name + ":0"
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index 3e19183..43cca1a 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -428,7 +428,7 @@
   def build(self, inputs_shape):
     if inputs_shape[-1] is None:
       raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
-                       % str(input_shape))
+                       % str(inputs_shape))
 
     input_depth = inputs_shape[-1]
     self._kernel = self.add_variable(
@@ -525,7 +525,7 @@
   def build(self, inputs_shape):
     if inputs_shape[-1] is None:
       raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
-                       % str(input_shape))
+                       % str(inputs_shape))
 
     input_depth = inputs_shape[-1]
     self._gate_kernel = self.add_variable(
@@ -705,7 +705,7 @@
   def build(self, inputs_shape):
     if inputs_shape[-1] is None:
       raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
-                       % str(input_shape))
+                       % str(inputs_shape))
 
     input_depth = inputs_shape[-1]
     h_depth = self._num_units
@@ -908,7 +908,7 @@
   def build(self, inputs_shape):
     if inputs_shape[-1] is None:
       raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
-                       % str(input_shape))
+                       % str(inputs_shape))
 
     input_depth = inputs_shape[-1]
     h_depth = self._num_units if self._num_proj is None else self._num_proj
diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py
index b2c6937..5d94946 100644
--- a/tensorflow/python/ops/string_ops.py
+++ b/tensorflow/python/ops/string_ops.py
@@ -29,14 +29,15 @@
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_string_ops
 from tensorflow.python.ops import math_ops
-from tensorflow.python.util import compat as util_compat
 
 # go/tf-wildcard-import
 # pylint: disable=wildcard-import
 from tensorflow.python.ops.gen_string_ops import *
+from tensorflow.python.util import compat as util_compat
 from tensorflow.python.util import deprecation
 from tensorflow.python.util.tf_export import tf_export
 # pylint: enable=wildcard-import
@@ -103,6 +104,87 @@
       rewrite=rewrite, replace_global=replace_global)
 
 
+@tf_export("strings.format")
+def string_format(template, inputs, placeholder="{}", summarize=3, name=None):
+  r"""Formats a string template using a list of tensors.
+
+  Formats a string template using a list of tensors, abbreviating tensors by
+  only printing the first and last `summarize` elements of each dimension
+  (recursively). If formatting only one tensor into a template, the tensor does
+  not have to be wrapped in a list.
+
+  Example:
+    Formatting a single-tensor template:
+    ```python
+    sess = tf.Session()
+    with sess.as_default():
+        tensor = tf.range(10)
+        formatted = tf.strings.format("tensor: {}, suffix", tensor)
+        out = sess.run(formatted)
+        expected = "tensor: [0 1 2 ... 7 8 9], suffix"
+
+        assert(out.decode() == expected)
+    ```
+
+    Formatting a multi-tensor template:
+    ```python
+    sess = tf.Session()
+    with sess.as_default():
+        tensor_one = tf.reshape(tf.range(100), [10, 10])
+        tensor_two = tf.range(10)
+        formatted = tf.strings.format("first: {}, second: {}, suffix",
+          (tensor_one, tensor_two))
+
+        out = sess.run(formatted)
+        expected = ("first: [[0 1 2 ... 7 8 9]\n"
+              " [10 11 12 ... 17 18 19]\n"
+              " [20 21 22 ... 27 28 29]\n"
+              " ...\n"
+              " [70 71 72 ... 77 78 79]\n"
+              " [80 81 82 ... 87 88 89]\n"
+              " [90 91 92 ... 97 98 99]], second: [0 1 2 ... 7 8 9], suffix")
+
+        assert(out.decode() == expected)
+    ```
+
+  Args:
+    template: A string template to format tensor values into.
+    inputs: A list of `Tensor` objects, or a single Tensor.
+      The list of tensors to format into the template string. If a solitary
+      tensor is passed in, the input tensor will automatically be wrapped as a
+      list.
+    placeholder: An optional `string`. Defaults to `{}`.
+      At each placeholder occurring in the template, a subsequent tensor
+      will be inserted.
+    summarize: An optional `int`. Defaults to `3`.
+      When formatting the tensors, show the first and last `summarize`
+      entries of each tensor dimension (recursively). If set to -1, all
+      elements of the tensor will be shown.
+    name: A name for the operation (optional).
+
+  Returns:
+    A scalar `Tensor` of type `string`.
+
+  Raises:
+    ValueError: if the number of placeholders does not match the number of
+      inputs.
+  """
+  # If there is only one tensor to format, we will automatically wrap it in a
+  # list to simplify the user experience
+  if tensor_util.is_tensor(inputs):
+    inputs = [inputs]
+  if template.count(placeholder) != len(inputs):
+    raise ValueError("%s placeholder(s) in template does not match %s tensor(s)"
+                     " provided as input" % (template.count(placeholder),
+                                             len(inputs)))
+
+  return gen_string_ops.string_format(inputs,
+                                      template=template,
+                                      placeholder=placeholder,
+                                      summarize=summarize,
+                                      name=name)
+
+
 @tf_export("string_split")
 def string_split(source, delimiter=" ", skip_empty=True):  # pylint: disable=invalid-name
   """Split elements of `source` based on `delimiter` into a `SparseTensor`.
diff --git a/tensorflow/python/ops/summary_ops_v2.py b/tensorflow/python/ops/summary_ops_v2.py
index 94c7d88..a404507 100644
--- a/tensorflow/python/ops/summary_ops_v2.py
+++ b/tensorflow/python/ops/summary_ops_v2.py
@@ -234,6 +234,7 @@
   """
   if logdir is None:
     return SummaryWriter(None, None)
+  logdir = str(logdir)
   with ops.device("cpu:0"):
     if max_queue is None:
       max_queue = constant_op.constant(10)
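Coercing `logdir` with `str()` here (and in `EventFileWriter` below) lets callers pass path-like objects such as `pathlib.Path`; the layers underneath expect a plain string. A hedged sketch of the call pattern this enables, assuming the `tf.contrib.summary` entry point:

```python
import pathlib
import tensorflow as tf

logdir = pathlib.Path("/tmp/tf_logs") / "run1"
# Previously this required passing str(logdir); with the coercion a Path works.
writer = tf.contrib.summary.create_file_writer(logdir)
```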
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
new file mode 100644
index 0000000..875be31
--- /dev/null
+++ b/tensorflow/python/ops/while_v2.py
@@ -0,0 +1,580 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""while_v2 and gradient.
+
+This is a version of while_loop that emits a single While op, as well as the
+gradient function for While ops produced by while_loop. This will eventually
+replace the current tf.while_loop implementation once it reaches feature and
+performance parity.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.eager import function
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import function_def_to_graph
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import cond_v2_impl as cond_v2
+from tensorflow.python.ops import control_flow_util
+from tensorflow.python.ops import gen_functional_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import list_ops
+from tensorflow.python.util import nest
+
+# pylint: disable=protected-access
+
+# TODO(b/79881896): Handle external control dependencies. tf.while_loop allows
+# control dependencies on external nodes with at least 1 output.
+# Another idea is to create const nodes outside the loop and add control edges
+# to them and then pass those in as data inputs. This should probably be
+# handled in the CapturingGraph itself.
+
+
+def while_loop(cond, body, loop_vars, name=None):
+  """Like tf.while_loop, except emits a single While op."""
+  if not name:
+    name = "while"
+
+  with ops.name_scope(name) as scope:
+    with ops.name_scope(None):
+      cond_name = _get_unique_name(("%scond" % scope).replace("/", "_"))
+      body_name = _get_unique_name(("%sbody" % scope).replace("/", "_"))
+
+    flattened_loop_vars = nest.flatten(loop_vars)
+    num_outputs = len(flattened_loop_vars)
+
+    # Add loop counter needed for computing gradients.
+    flattened_loop_vars = [constant_op.constant(0., name="loop_counter")
+                          ] + flattened_loop_vars
+
+    # Build a `cond` wrapper that can handle the extra counter loop_var.
+    def wrapped_cond(unused_loop_counter, *loop_vars):
+      return cond(*loop_vars)
+
+    cond_graph = function.func_graph_from_py_func(cond_name, wrapped_cond,
+                                                  flattened_loop_vars, {})
+
+    # Add external_captures of cond to the list of loop vars.
+    # Note that external tensors will be treated as loop invariants, i.e.,
+    # the value of that tensor in each iteration is the same as it was at the
+    # beginning of the loop execution.
+    flattened_loop_vars = flattened_loop_vars + cond_graph.external_captures
+
+    def wrapped_body(loop_counter, *args):
+      """Loop body augmented with counter update.
+
+      Args:
+        loop_counter: Loop counter which needs to be incremented in the body.
+        *args: List of args
+          args[:num_outputs] - Args for the original loop body.
+          args[num_outputs:] - External captures of cond. These get passed
+            through as is.
+
+      Returns:
+        A list of tensors the same length as args.
+      """
+      outputs = body(*args[:num_outputs])
+      if not isinstance(outputs, collections.Sequence):
+        outputs = [outputs]
+
+      # Return the external_captures of cond_graph as is, i.e., treat them as
+      # loop invariants.
+      # TODO(srbs): Update lowering code to create _Enter nodes with
+      # is_constant=True for inputs that are directly passed to outputs.
+      return [loop_counter + 1] + list(outputs) + list(args[num_outputs:])
+
+    body_graph = function.func_graph_from_py_func(body_name, wrapped_body,
+                                                  flattened_loop_vars, {})
+    # Add external captures of body to the list of loop vars.
+    # Note that external tensors will be treated as loop invariants, i.e.,
+    # the value of that tensor in each iteration is the same as it was at the
+    # beginning of the loop execution.
+    flattened_loop_vars = flattened_loop_vars + body_graph.external_captures
+    # TODO(srbs): Update lowering code to create _Enter nodes with
+    # is_constant=True for inputs that are directly passed to outputs.
+    body_graph.outputs.extend(body_graph.internal_captures)
+
+    # Capture `external_captures` of `body_graph` in `cond_graph` so that it
+    # expects to receive those as arguments.
+    # TODO(srbs): Dedup tensors that are captured in both the cond and body.
+    # This logic already exists in cond_v2.
+    with cond_graph.as_default():
+      for external_capture in body_graph.external_captures:
+        cond_graph.capture(external_capture)
+
+    # Export all tensors in the loop body that may be needed for gradient
+    # computation. We do this by accumulating the intermediate values in
+    # TensorLists.
+    intermediate_tensors = _get_intermediates(body_graph)
+
+    for intermediate_tensor in intermediate_tensors:
+      # TODO(srbs): Cache and re-use empty tensor lists.
+      tensor_list = list_ops.empty_tensor_list(
+          element_dtype=intermediate_tensor.dtype,
+          element_shape=_get_tensor_convertible_shape(
+              intermediate_tensor.shape))
+      flattened_loop_vars.append(tensor_list)
+      with cond_graph.as_default():
+        # Add a placeholder to cond_graph's inputs corresponding to the
+        # tensor_list.
+        cond_graph.capture(tensor_list)
+      with body_graph.as_default():
+        # Push the intermediate tensor to the tensor list. This captures the
+        # `tensor_list` as well.
+        appended_tensor_list = list_ops.tensor_list_push_back(
+            tensor_list,
+            intermediate_tensor)
+        # Add this modified tensor list to the list of outputs.
+        body_graph.outputs.append(appended_tensor_list)
+
+    outputs = gen_functional_ops._while(
+        flattened_loop_vars,
+        cond_v2._create_new_tf_function(cond_graph),
+        cond_v2._create_new_tf_function(body_graph),
+        name=scope)
+
+    _copy_handle_data(body_graph.outputs, outputs)
+    _maybe_set_lowering_attr(outputs[0].op)
+
+  # First var is loop counter.
+  if num_outputs == 1:
+    return outputs[1]
+  else:
+    return nest.pack_sequence_as(loop_vars, outputs[1:1 + num_outputs])
+
+
+@ops.RegisterGradient("While")
+def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
+  """The gradient of a While op produced by while_loop."""
+  body_graph = _get_body_graph(op)
+
+  # Replace None gradients with zeros. This is needed because `grads` could have
+  # None incoming gradients for the TensorLists. If we pass Nones through, the
+  # custom gradient of TensorListPopBack will create an EmptyTensorList inside
+  # the FuncGraph which is undesirable.
+  # TODO(b/80444525): There might be an issue with treating no gradient as zero
+  # gradient in certain cases. Consider replacing None gradients with Zeros
+  # for accumulators only.
+  grads = [
+      g if g is not None else array_ops.zeros_like(output)
+      for g, output in zip(grads, op.outputs)
+  ]
+
+  body_grad_graph, args = _create_grad_func(
+      body_graph, grads,
+      _get_unique_name("%s_grad" % body_graph.name), op)
+
+  intermediate_tensors = _get_intermediates(body_grad_graph)
+
+  for intermediate_tensor in intermediate_tensors:
+    tensor_list = list_ops.empty_tensor_list(
+        element_dtype=intermediate_tensor.dtype,
+        element_shape=_get_tensor_convertible_shape(intermediate_tensor.shape))
+    with body_grad_graph.as_default():
+      tensor_list_ph = body_grad_graph.capture(tensor_list, whitelisted=True)
+      # Push the intermediate tensor to the tensor list.
+      appended_tensor_list = list_ops.tensor_list_push_back(tensor_list_ph,
+                                                            intermediate_tensor)
+      # Add this modified tensor list to the list of outputs.
+      body_grad_graph.outputs.append(appended_tensor_list)
+
+  def grad_cond(counter, max_iters, *unused_args):
+    return counter < max_iters
+
+  loop_vars = args + body_grad_graph.external_captures
+  cond_grad_graph = function.func_graph_from_py_func(
+      _get_unique_name("%s_grad_cond" % op.name),
+      grad_cond, loop_vars, {})
+
+  assert len(loop_vars) == len(body_grad_graph.inputs)
+  assert len(loop_vars) == len(body_grad_graph.outputs)
+  assert len(loop_vars) == len(cond_grad_graph.inputs)
+
+  outputs = gen_functional_ops._while(
+      loop_vars,
+      cond_v2._create_new_tf_function(cond_grad_graph),
+      cond_v2._create_new_tf_function(body_grad_graph),
+      name=_get_unique_name("%s_grad" % op.name))
+
+  _copy_handle_data(body_grad_graph.outputs, outputs)
+  _maybe_set_lowering_attr(outputs[0].op)
+
+  # outputs[0] is the loop counter.
+  # outputs[1] is the total number of loop iterations.
+  return outputs[2:2 + len(op.inputs)]
+
+
+# TODO(srbs): Pull this into common utils for cond_v2 and while_v2.
+def _get_body_graph(while_op):
+  """Returns `FuncGraph` for the while body.
+
+  Args:
+    while_op: The While Operation.
+
+  Returns:
+    `FuncGraph` for the while body.
+  """
+  extra_inputs = list(while_op.inputs)
+  input_shapes = [t.shape for t in extra_inputs]
+  func_name = while_op.get_attr("body").name
+  fdef = while_op.graph._get_function(func_name).definition
+  func_graph = function_def_to_graph.function_def_to_graph(fdef, input_shapes)
+  func_graph._while = while_op
+  return func_graph
+
+
+def _create_grad_func(func_graph, grads, name, while_op):
+  """Builds and returns the gradient FuncGraph of `func_graph` and its args.
+
+  The returned grad_func_graph must be called with the returned
+  args + grad_func_graph.captures.
+
+  Args:
+    func_graph: FuncGraph for the forward body function.
+    grads: The incoming grads for `func_graph`'s outputs.
+    name: Name of the returned gradient function.
+    while_op: The forward While op.
+
+  Returns:
+    2-tuple of (grad_func_graph, args).
+  """
+  assert len(func_graph.outputs) == len(grads)
+
+  loop_counter = constant_op.constant(0.)
+  # TODO(srbs): For nested while loops we will need to look up this value from
+  # the accumulator of the enclosing while loop. For now, use it as is,
+  # assuming there is no nesting.
+  num_iters_t = while_op.outputs[0]
+
+  args = [loop_counter, num_iters_t] + grads
+
+  # Note: The returned function does not have `args` in the list of
+  # `external_captures`.
+  grad_func_graph = function.func_graph_from_py_func(
+      name,
+      lambda *args: _grad_fn(func_graph, args),
+      args, {},
+      func_graph=_WhileBodyGradFuncGraph(name, func_graph))
+
+  # Add the popped accumulators to the list of outputs.
+  for internal_capture in grad_func_graph.internal_captures:
+    grad_func_graph.outputs.append(
+        grad_func_graph.popped_tensor_lists[internal_capture])
+
+  return grad_func_graph, args
+
+
+def _grad_fn(func_graph, args):
+  """Computes the gradient of `func_graph` in the current graph.
+
+  This function builds the gradient graph of the corresponding forward-pass
+  `func_graph` by differentiating `func_graph`'s outputs w.r.t. its inputs.
+
+  Args:
+    func_graph: function.FuncGraph. The corresponding forward-pass function.
+    args: The input arguments.
+      args[0] - Loop counter.
+      args[1] - Total number of iterations.
+      args[2:] - Incoming gradients for `func_graph.outputs`.
+
+  Returns:
+    The output gradient Tensors.
+  """
+  xs = func_graph.inputs
+  ys = func_graph.outputs
+  grad_ys = args[2:]
+
+  # Build the gradient graph. Note that this builds the gradient computation of
+  # func_graph in the current graph, which requires capturing tensors from
+  # func_graph. The captured func_graph tensors are resolved to external tensors
+  # in _resolve_grad_inputs.
+  # TODO(srbs): Mark GradientsHelper as public?
+  grad_outs = gradients_impl._GradientsHelper(
+      ys, xs, grad_ys=grad_ys, src_graph=func_graph)
+
+  assert all([g is not None for g in grad_outs])
+  counter = args[0]
+  total_iters = args[1]
+  return [counter + 1, total_iters] + grad_outs
+
+
+def _get_intermediates(func_graph):
+  """Returns all tensors in `func_graph` that should be accumulated."""
+  # We currently accumulate output tensors of most ops in the function and rely
+  # on the pruning pass to get rid of the unused accumulators at runtime.
+  # However, this can bloat the GraphDef and make debugging harder so we perform
+  # some optimizations.
+  #
+  # Optimizations we currently perform:
+  # 1. We do not accumulate tensors which already have an accumulator
+  #    in the loop body.
+  # 2. We do not accumulate outputs of Identity nodes. When building the
+  #    FuncGraph, we add an Identity node for each output (see
+  #    `AutomaticControlDependencies.mark_as_return`). Accumulating outputs
+  #    of all these nodes bloats the GraphDef quite a bit so we remove those.
+  #    Since the gradient of an Identity node does not rely on its forward op's
+  #    input, this is safe to do.
+  #
+  # Other possible optimizations:
+  # 1. Only accumulate tensors that will be required by the backward pass.
+  #    This will require running the gradient pass and hence would increase the
+  #    graph building time for the forward pass.
+  # 2. Do not accumulate Const nodes created inside the loop body.
+  # 3. Do not accumulate inputs that are passed as-is, e.g. loop invariants.
+  # TODO(srbs): 2 and 3 may be hard optimizations for the runtime optimizer
+  # since it requires knowledge of the while loop semantics. If so, consider
+  # doing those here.
+  intermediates = []
+
+  for op in func_graph.get_operations():
+    if op.type == "Identity":
+      continue
+    for o in op.outputs:
+      if (o != func_graph.inputs[0] and  # Loop counter.
+          _get_accumulator(o) is None):  # No existing accumulator.
+        intermediates.append(o)
+  return intermediates
+
+
+def _get_accumulator(tensor):
+  r"""Returns TensorList if any containing accumulated values of tensor.
+
+  We try to find a pattern of the form:
+
+     input_tl   tensor
+        \        /
+    (TensorListPushBack)
+            |
+        output_tl
+
+  which satisfies the following conditions:
+
+  1. input_tl must be in tensor.graph.inputs.
+  2. output_tl or Identity(output_tl) must be in tensor.graph.outputs.
+  3. tensor.graph.inputs.index(input_tl)
+     == tensor.graph.outputs.index(output_tl).
+
+  output_tl or Identity(output_tl) (whichever is in tensor.graph.outputs) is
+  returned if such a pattern is found, else None is returned.
+
+  Args:
+    tensor: The Tensor to be accumulated.
+
+  Returns:
+    A variant tensor in the same graph as `tensor` or None if no accumulator is
+    found.
+  """
+  assert isinstance(tensor.graph, function.FuncGraph)
+
+  def get_func_graph_output(t):
+    """Returns t or Identity(t) whichever exists in graph outputs else None."""
+    if t in tensor.graph.outputs:
+      return t
+    # tf.defun adds an Identity for each output, check whether that is the case.
+    identity_op = t.consumers()[0]
+    if (identity_op.type == "Identity" and
+        identity_op.outputs[0] in tensor.graph.outputs):
+      return identity_op.outputs[0]
+    return None
+
+  for consumer in tensor.consumers():
+    # Find the consumer that is a TensorListPushBack node whose TensorList input
+    # is in the list of function inputs.
+    if (consumer.type != "TensorListPushBack" or
+        consumer.inputs[0] not in tensor.graph.inputs):
+      continue
+
+    output = get_func_graph_output(consumer.outputs[0])
+    if output is None:
+      # The TensorList output of `consumer` is not in the list of function
+      # outputs.
+      continue
+
+    accum_input_idx = tensor.graph.inputs.index(consumer.inputs[0])
+    accum_output_idx = tensor.graph.outputs.index(output)
+    if accum_input_idx == accum_output_idx:
+      return output
+  return None
+
+
+# TODO(srbs): Add to common utils for cond_v2 and while_v2.
+def _get_unique_name(name):
+  """Returns a name that is unique in the root graph of `func_graph`.
+
+  Args:
+    name: String to uniquify.
+
+  Returns:
+    A string.
+  """
+  with ops.init_scope():
+    return ops.get_default_graph().unique_name(name)
+
+
+class _WhileBodyGradFuncGraph(function.FuncGraph):
+  """FuncGraph for the gradient function of the body of a While op.
+
+  Contains the logic for capturing the tensors from the body of the forward
+  While op, which is as follows:
+  1. Find the accumulator for that tensor.
+  2. Capture the forward While op output tensor corresponding to the
+     accumulator in this FuncGraph.
+  3. Pop a value from the captured placeholder and use it as the captured value
+     for the forward pass tensor.
+
+  This only allows capturing tensors in the forward graph. A ValueError is
+  raised if an attempt is made to capture a tensor not in the forward graph.
+  To manually capture a tensor that is not in the forward graph, call
+  `capture` with `whitelisted=True`.
+
+  Note: The `captures` dict does not contain the forward tensor since it is not
+  directly captured. It contains the accumulator corresponding to this forward
+  tensor.
+
+  Attributes:
+    popped_tensor_lists: Dict from the captured accumulator placeholder to the
+      TensorList obtained after popping the intermediate tensor from it. The
+      values of this dict need to be added to the list of outputs.
+  """
+
+  def __init__(self, name, forward_graph):
+    super(_WhileBodyGradFuncGraph, self).__init__(name)
+    self.popped_tensor_lists = {}
+    # FuncGraph for the body of the forward While op.
+    self._forward_graph = forward_graph
+    # Dict from forward intermediate tensor to the corresponding "popped" tensor
+    # in this graph.
+    self._indirect_captures = {}
+    # Dict from forward graph tensor to the While op output corresponding to its
+    # accumulator.
+    self._tensor_to_accumulator = {}
+
+  def capture(self, tensor, name=None, whitelisted=False):
+    """Selectively captures external tensors.
+
+    If `whitelisted` is False only allows capturing tensors in the
+    `_forward_graph`.
+
+    Args:
+      tensor: Tensor. May be from this FuncGraph or a different graph.
+      name: Optional name if a placeholder is created.
+      whitelisted: If False (default), only allows capturing tensors from the
+        forward graph.
+
+    Returns:
+      The placeholder in this graph for the tensor.
+
+    Raises:
+      ValueError: If attempting to capture an external tensor not in the forward
+        graph with `whitelisted` set to False.
+    """
+    if (not whitelisted and tensor.graph is not self and
+        tensor.graph != self._forward_graph):
+      raise ValueError("Attempting to capture tensor", str(tensor),
+                       " which is not in the forward graph but in ",
+                       _graph_name(tensor.graph), ".")
+    return super(_WhileBodyGradFuncGraph, self).capture(tensor, name)
+
+  def _capture_helper(self, tensor, name):
+    if tensor.graph is not self._forward_graph:
+      return super(_WhileBodyGradFuncGraph, self)._capture_helper(tensor, name)
+
+    captured_tensor = self._indirect_captures.get(tensor)
+    if captured_tensor is not None:
+      # For GradientTape housekeeping.
+      assert self._tensor_to_accumulator[tensor] in self.captures
+      super(_WhileBodyGradFuncGraph, self)._capture_helper(
+          self._tensor_to_accumulator[tensor], name)
+      return captured_tensor
+
+    assert tensor not in self._tensor_to_accumulator
+
+    accumulator = None
+
+    # Find the TensorList that was used to accumulate the tensors of this
+    # intermediate tensor.
+    accumulator = _get_accumulator(tensor)
+    if accumulator is None:
+      raise ValueError("Reference to un-accumulated intermediate tensor: ",
+                       tensor.name)
+    assert accumulator.graph == self._forward_graph
+    # Get the While op output corresponding to the accumulator.
+    accumulator = self._forward_graph._while.outputs[self._forward_graph.outputs
+                                                     .index(accumulator)]
+
+    assert accumulator.graph == self._forward_graph.outer_graph
+    self._tensor_to_accumulator[tensor] = accumulator
+
+    # Capture the `accumulator`.
+    accumulator_ph = super(_WhileBodyGradFuncGraph, self)._capture_helper(
+        accumulator, name)
+    new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back(
+        accumulator_ph, element_dtype=tensor.dtype)
+    self._indirect_captures[tensor] = captured_tensor
+    self.popped_tensor_lists[accumulator_ph] = new_tensor_list
+    return captured_tensor
+
+
+def _copy_handle_data(src_tensors, tgt_tensors):
+  for src_t, tgt_t in zip(src_tensors, tgt_tensors):
+    function._copy_handle_data(src_t, tgt_t)
+
+
+# TODO(srbs): Move to common utils for cond_v2 and while_v2.
+def _maybe_set_lowering_attr(op):
+  """Sets the flag to enable lowering on the `While` op if necessary.
+
+  Lowering allows while_v2 to avoid some of the limitations of Functions,
+  allowing users to specify devices & colocation inside of while_v2
+  branches, and enabling non-strict evaluation & partial pruning of while_v2
+  branches. This brings while_v2 closer to feature parity with
+  tf.while_loop.
+
+  However, we do not lower `While` in the XLA context because it is easier
+  for XLA to apply its own optimizations when dealing with un-lowered
+  `While` operators than with low-level control flow primitives.
+
+  Args:
+    op: The While op.
+  """
+  if not control_flow_util.IsInXLAContext(op):
+    # pylint: disable=protected-access
+    op._set_attr("_lower_using_switch_merge", attr_value_pb2.AttrValue(b=True))
+    # pylint: enable=protected-access
+
+
+def _get_tensor_convertible_shape(shape):
+  assert isinstance(shape, tensor_shape.TensorShape)
+  if shape.is_fully_defined():
+    return shape
+  if not shape:  # Unknown shape.
+    return -1
+  # Partially defined shape.
+  shape_list = shape.as_list()
+  shape_list = [s if s is not None else -1 for s in shape_list]
+  return ops.convert_to_tensor(shape_list)
+
+
+def _graph_name(graph):
+  if isinstance(graph, function.FuncGraph):
+    return graph.name
+  return "Base"
+
+
+# pylint: enable=protected-access
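As the module docstring states, `while_v2.while_loop` is meant as a drop-in replacement for `tf.while_loop` that emits a single `While` op and accumulates intermediates in TensorLists for the gradient pass. A hedged usage sketch against the internal module added by this diff:

```python
import tensorflow as tf
from tensorflow.python.ops import while_v2

x = tf.constant(2.0)
# Same cond/body signature as tf.while_loop; a single While op is emitted.
y = while_v2.while_loop(lambda v: v < 100.0, lambda v: v * 2.0, [x])
dy_dx = tf.gradients(y, x)[0]  # handled by the _WhileGrad registered above

with tf.Session() as sess:
  print(sess.run([y, dy_dx]))  # doubling until >= 100: y == 128.0, dy/dx == 64.0
```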
diff --git a/tensorflow/python/profiler/pprof_profiler_test.py b/tensorflow/python/profiler/pprof_profiler_test.py
index c2469f0..11a3487 100644
--- a/tensorflow/python/profiler/pprof_profiler_test.py
+++ b/tensorflow/python/profiler/pprof_profiler_test.py
@@ -141,7 +141,7 @@
     run_metadata = config_pb2.RunMetadata()
 
     num_iters = 5
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       i = constant_op.constant(0)
       c = lambda i: math_ops.less(i, num_iters)
       b = lambda i: math_ops.add(i, 1)
diff --git a/tensorflow/python/pywrap_tensorflow.py b/tensorflow/python/pywrap_tensorflow.py
index 5c0c578..f072427 100644
--- a/tensorflow/python/pywrap_tensorflow.py
+++ b/tensorflow/python/pywrap_tensorflow.py
@@ -68,7 +68,7 @@
     sys.setdlopenflags(_default_dlopen_flags)
 except ImportError:
   msg = """%s\n\nFailed to load the native TensorFlow runtime.\n
-See https://www.tensorflow.org/install/install_sources#common_installation_problems\n
+See https://www.tensorflow.org/install/errors\n
 for some common reasons and solutions.  Include the entire stack trace
 above this error message when asking for help.""" % traceback.format_exc()
   raise ImportError(msg)
diff --git a/tensorflow/python/summary/writer/event_file_writer.py b/tensorflow/python/summary/writer/event_file_writer.py
index 2936a27..14dec98 100644
--- a/tensorflow/python/summary/writer/event_file_writer.py
+++ b/tensorflow/python/summary/writer/event_file_writer.py
@@ -62,7 +62,7 @@
       filename_suffix: A string. Every event file's name is suffixed with
         `filename_suffix`.
     """
-    self._logdir = logdir
+    self._logdir = str(logdir)
     if not gfile.IsDirectory(self._logdir):
       gfile.MakeDirs(self._logdir)
     self._event_queue = six.moves.queue.Queue(max_queue)
diff --git a/tensorflow/python/summary/writer/writer_test.py b/tensorflow/python/summary/writer/writer_test.py
index dc990c2..670230e 100644
--- a/tensorflow/python/summary/writer/writer_test.py
+++ b/tensorflow/python/summary/writer/writer_test.py
@@ -286,7 +286,7 @@
   def testAddingSummariesFromSessionRunCalls(self):
     test_dir = self._CleanTestDir("global_step")
     sw = self._FileWriter(test_dir)
-    with self.test_session():
+    with self.cached_session():
       i = constant_op.constant(1, dtype=dtypes.int32, shape=[])
       l = constant_op.constant(2, dtype=dtypes.int64, shape=[])
       # Test the summary can be passed serialized.
@@ -437,7 +437,7 @@
       # Pass in test_session() as the session. It will be cached during this
       # test method invocation so that any other use of test_session() with no
       # graph should result in re-using the same underlying Session.
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         kwargs["session"] = sess
         return writer.FileWriter(*args, **kwargs)
     return writer.FileWriter(*args, **kwargs)
diff --git a/tensorflow/python/tools/optimize_for_inference_test.py b/tensorflow/python/tools/optimize_for_inference_test.py
index fcb3cea..a39c046 100644
--- a/tensorflow/python/tools/optimize_for_inference_test.py
+++ b/tensorflow/python/tools/optimize_for_inference_test.py
@@ -129,7 +129,7 @@
     self.assertProtoEquals(expected_output, output)
 
   def testFoldBatchNorms(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
       input_op = constant_op.constant(
           np.array(inputs), shape=[1, 1, 6, 2], dtype=dtypes.float32)
@@ -161,7 +161,7 @@
     optimized_graph_def = optimize_for_inference_lib.fold_batch_norms(
         original_graph_def)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _ = importer.import_graph_def(
           optimized_graph_def, input_map={}, name="optimized")
       optimized_result = sess.run(["optimized/output:0"])
@@ -224,7 +224,7 @@
         self.assertNotEqual("FusedBatchNorm", node.op)
 
   def testFuseResizePadAndConv(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
       input_op = constant_op.constant(
           np.array(inputs), shape=[1, 2, 3, 2], dtype=dtypes.float32)
@@ -242,7 +242,7 @@
     optimized_graph_def = optimize_for_inference_lib.fuse_resize_and_conv(
         original_graph_def, ["output"])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _ = importer.import_graph_def(
           optimized_graph_def, input_map={}, name="optimized")
       optimized_result = sess.run(["optimized/output:0"])
@@ -255,7 +255,7 @@
       self.assertNotEqual("ResizeBilinear", node.op)
 
   def testFuseResizeAndConv(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
       input_op = constant_op.constant(
           np.array(inputs), shape=[1, 2, 3, 2], dtype=dtypes.float32)
@@ -271,7 +271,7 @@
     optimized_graph_def = optimize_for_inference_lib.fuse_resize_and_conv(
         original_graph_def, ["output"])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _ = importer.import_graph_def(
           optimized_graph_def, input_map={}, name="optimized")
       optimized_result = sess.run(["optimized/output:0"])
@@ -284,7 +284,7 @@
 
 
   def testFusePadAndConv(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
       input_op = constant_op.constant(
           np.array(inputs), shape=[1, 2, 3, 2], dtype=dtypes.float32)
@@ -300,7 +300,7 @@
     optimized_graph_def = optimize_for_inference_lib.fuse_resize_and_conv(
         original_graph_def, ["output"])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _ = importer.import_graph_def(
           optimized_graph_def, input_map={}, name="optimized")
       optimized_result = sess.run(["optimized/output:0"])
diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py
index d8ba13d8..3dbccd1 100644
--- a/tensorflow/python/tools/saved_model_cli.py
+++ b/tensorflow/python/tools/saved_model_cli.py
@@ -15,7 +15,7 @@
 """Command-line interface to inspect and execute a graph in a SavedModel.
 
 For detailed usages and examples, please refer to:
-https://www.tensorflow.org/guide/saved_model_cli
+https://www.tensorflow.org/guide/saved_model#cli_to_inspect_and_execute_savedmodel
 
 """
 
diff --git a/tensorflow/python/training/adagrad.py b/tensorflow/python/training/adagrad.py
index 3508b98..cc0da26 100644
--- a/tensorflow/python/training/adagrad.py
+++ b/tensorflow/python/training/adagrad.py
@@ -34,7 +34,7 @@
 
   See this [paper](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
   or this
-  [intro](http://cs.stanford.edu/~ppasupat/a9online/uploads/proximal_notes.pdf).
+  [intro](https://ppasupat.github.io/a9online/uploads/proximal_notes.pdf).
   """
 
   def __init__(self, learning_rate, initial_accumulator_value=0.1,
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index 21ca173..419a9ec 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -195,6 +195,10 @@
 class DistributionStrategy(object):
   """A list of devices with a state & compute distribution policy.
 
+  See [tensorflow/contrib/distribute/README.md](
+  https://www.tensorflow.org/code/tensorflow/contrib/distribute/README.md)
+  for overview and examples.
+
   The intent is that you can write an algorithm in a stylized way and
   it will be usable with a variety of different `DistributionStrategy`
   implementations. Each descendant will implement a different strategy
diff --git a/tensorflow/python/training/gradient_descent_test.py b/tensorflow/python/training/gradient_descent_test.py
index 56d82a5..1ddea59 100644
--- a/tensorflow/python/training/gradient_descent_test.py
+++ b/tensorflow/python/training/gradient_descent_test.py
@@ -252,12 +252,12 @@
       optimizer = gradient_descent.GradientDescentOptimizer(1.0)
 
       def step():
-        v = resource_variable_ops.ResourceVariable(1.0)
+        self.v = resource_variable_ops.ResourceVariable(1.0)
         with backprop.GradientTape() as tape:
-          loss = v ** 2
-        grad = tape.gradient(loss, v)
-        optimizer.apply_gradients([(grad, v)])
-        return v.read_value()
+          loss = self.v ** 2
+        grad = tape.gradient(loss, self.v)
+        optimizer.apply_gradients([(grad, self.v)])
+        return self.v.read_value()
 
       compiled_step = function.defun(step)
 
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 2304a46..699162b 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -385,13 +385,12 @@
 
     @compatibility(eager)
     When eager execution is enabled, `loss` should be a Python function that
-    takes elements of `var_list` as arguments and computes the value to be
-    minimized. If `var_list` is None, `loss` should take no arguments.
-    Minimization (and gradient computation) is done with respect to the
-    elements of `var_list` if not None, else with respect to any trainable
-    variables created during the execution of the `loss` function.
-    `gate_gradients`, `aggregation_method`, `colocate_gradients_with_ops` and
-    `grad_loss` are ignored when eager execution is enabled.
+    takes no arguments and computes the value to be minimized. Minimization (and
+    gradient computation) is done with respect to the elements of `var_list` if
+    not None, else with respect to any trainable variables created during the
+    execution of the `loss` function. `gate_gradients`, `aggregation_method`,
+    `colocate_gradients_with_ops` and `grad_loss` are ignored when eager
+    execution is enabled.
     @end_compatibility
     """
     grads_and_vars = self.compute_gradients(
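The rewritten paragraph reflects the eager-mode contract: `loss` must be a zero-argument callable, and the variables to optimize come from `var_list` (or from any trainable variables created while evaluating the callable). A minimal eager-mode sketch under the TF 1.x API:

```python
import tensorflow as tf
tf.enable_eager_execution()

v = tf.contrib.eager.Variable(3.0)
opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)

# `loss` takes no arguments under eager execution; minimization is done w.r.t.
# `var_list` (here just [v]).
opt.minimize(lambda: v * v, var_list=[v])
print(v.numpy())  # 3.0 - 0.1 * d(v**2)/dv = 3.0 - 0.6 = 2.4
```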
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 3c533c7..3a77ba7 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -132,23 +132,42 @@
 }
 
 template <typename T>
-cudnnDataType_t GetCudnnDataType();
+cudnnDataType_t GetCudnnDataType(
+    dnn::DataLayout = dnn::DataLayout::kBatchDepthYX);
 
 template <>
-cudnnDataType_t GetCudnnDataType<double>() {
+cudnnDataType_t GetCudnnDataType<double>(dnn::DataLayout) {
   return CUDNN_DATA_DOUBLE;
 }
 
 template <>
-cudnnDataType_t GetCudnnDataType<float>() {
+cudnnDataType_t GetCudnnDataType<float>(dnn::DataLayout) {
   return CUDNN_DATA_FLOAT;
 }
 
 template <>
-cudnnDataType_t GetCudnnDataType<Eigen::half>() {
+cudnnDataType_t GetCudnnDataType<Eigen::half>(dnn::DataLayout) {
   return CUDNN_DATA_HALF;
 }
 
+template <>
+cudnnDataType_t GetCudnnDataType<int8>(dnn::DataLayout layout) {
+  switch (layout) {
+    case dnn::DataLayout::kYXDepthBatch:
+    case dnn::DataLayout::kYXBatchDepth:
+    case dnn::DataLayout::kBatchYXDepth:
+    case dnn::DataLayout::kBatchDepthYX:
+      return CUDNN_DATA_INT8;
+    case dnn::DataLayout::kBatchDepthYX4:
+      return CUDNN_DATA_INT8x4;
+  }
+}
+
+template <>
+cudnnDataType_t GetCudnnDataType<int32>(dnn::DataLayout) {
+  return CUDNN_DATA_INT32;
+}
+
 // RAII wrapper for all calls to cuDNN with a cuDNN handle argument.
 //
 // See CudnnAccess::GetHandle() for details.
@@ -2486,19 +2505,19 @@
   return port::Status::OK();
 }
 
-template <typename Type, typename BiasType, typename ScaleType,
-          int cudnn_data_type, int cudnn_compute_type>
+template <typename AccumulatorType, typename ElementType, typename BiasType,
+          typename ScaleType>
 port::Status CudnnSupport::DoFusedConvolveImpl(
     Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
-    const DeviceMemory<Type>& conv_input_data, ScaleType conv_input_scale,
-    const dnn::FilterDescriptor& filter_descriptor,
-    const DeviceMemory<Type>& filter_data,
+    const DeviceMemory<ElementType>& conv_input_data,
+    ScaleType conv_input_scale, const dnn::FilterDescriptor& filter_descriptor,
+    const DeviceMemory<ElementType>& filter_data,
     const dnn::ConvolutionDescriptor& convolution_descriptor,
-    const DeviceMemory<Type>& side_input_data, ScaleType side_input_scale,
-    const dnn::BatchDescriptor& bias_descriptor,
+    const DeviceMemory<ElementType>& side_input_data,
+    ScaleType side_input_scale, const dnn::BatchDescriptor& bias_descriptor,
     const DeviceMemory<BiasType>& biases, dnn::ActivationMode activation_mode,
     const dnn::BatchDescriptor& output_descriptor,
-    DeviceMemory<Type>* output_data, ScratchAllocator* scratch_allocator,
+    DeviceMemory<ElementType>* output_data, ScratchAllocator* scratch_allocator,
     const dnn::AlgorithmConfig& algorithm_config,
     dnn::ProfileResult* output_profile_result) {
   if (activation_mode != dnn::ActivationMode::kRelu &&
@@ -2509,14 +2528,17 @@
   }
 
   CudnnTensorDescriptor conv_input_nd(
-      conv_input_descriptor, static_cast<cudnnDataType_t>(cudnn_data_type));
+      conv_input_descriptor,
+      GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
   CudnnTensorDescriptor output_nd(
-      output_descriptor, static_cast<cudnnDataType_t>(cudnn_data_type));
-  CudnnFilterDescriptor filter(filter_descriptor,
-                               static_cast<cudnnDataType_t>(cudnn_data_type));
-  CudnnTensorDescriptor bias_nd(bias_descriptor, CUDNN_DATA_FLOAT);
-  CudnnConvolutionDescriptor conv(
-      convolution_descriptor, static_cast<cudnnDataType_t>(cudnn_compute_type));
+      output_descriptor,
+      GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
+  CudnnFilterDescriptor filter(
+      filter_descriptor,
+      GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
+  CudnnTensorDescriptor bias_nd(bias_descriptor, GetCudnnDataType<BiasType>());
+  CudnnConvolutionDescriptor conv(convolution_descriptor,
+                                  GetCudnnDataType<AccumulatorType>());
 
   auto cudnn = cudnn_->GetHandle(parent_, stream);
 
@@ -2933,8 +2955,7 @@
     const dnn::AlgorithmConfig& algorithm_config,
     dnn::ProfileResult* output_profile_result) {
   return IsStatusOk(
-      DoFusedConvolveImpl<double, double, double, CUDNN_DATA_DOUBLE,
-                          CUDNN_DATA_DOUBLE>(
+      DoFusedConvolveImpl<double>(
           stream, conv_input_descriptor, conv_input_data, conv_input_scale,
           filter_descriptor, filter_data, convolution_descriptor,
           side_input_data, side_input_scale, bias_descriptor, biases,
@@ -2957,8 +2978,7 @@
     const dnn::AlgorithmConfig& algorithm_config,
     dnn::ProfileResult* output_profile_result) {
   return IsStatusOk(
-      DoFusedConvolveImpl<float, float, float, CUDNN_DATA_FLOAT,
-                          CUDNN_DATA_FLOAT>(
+      DoFusedConvolveImpl<float>(
           stream, conv_input_descriptor, conv_input_data, conv_input_scale,
           filter_descriptor, filter_data, convolution_descriptor,
           side_input_data, side_input_scale, bias_descriptor, biases,
@@ -2982,8 +3002,7 @@
     const dnn::AlgorithmConfig& algorithm_config,
     dnn::ProfileResult* output_profile_result) {
   return IsStatusOk(
-      DoFusedConvolveImpl<Eigen::half, Eigen::half, float, CUDNN_DATA_HALF,
-                          CUDNN_DATA_FLOAT>(
+      DoFusedConvolveImpl<float>(
           stream, conv_input_descriptor, conv_input_data, conv_input_scale,
           filter_descriptor, filter_data, convolution_descriptor,
           side_input_data, side_input_scale, bias_descriptor, biases,
@@ -3014,8 +3033,7 @@
     return false;
   }
   return IsStatusOk(
-      DoFusedConvolveImpl<int8, float, float, CUDNN_DATA_INT8x4,
-                          CUDNN_DATA_INT32>(
+      DoFusedConvolveImpl<int32>(
           stream, conv_input_descriptor, conv_input_data, conv_input_scale,
           filter_descriptor, filter_data, convolution_descriptor,
           side_input_data, side_input_scale, bias_descriptor, biases,
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h
index 9d88f97..74f6f93 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.h
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.h
@@ -674,19 +674,21 @@
       const dnn::AlgorithmConfig& algorithm_config,
       dnn::ProfileResult* output_profile_result);
 
-  template <typename Type, typename BiasType, typename ScaleType,
-            int cudnn_data_type, int cudnn_compute_type>
+  template <typename AccumulatorType, typename ElementType, typename BiasType,
+            typename ScaleType>
   port::Status DoFusedConvolveImpl(
       Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
-      const DeviceMemory<Type>& conv_input_data, ScaleType conv_input_scale,
+      const DeviceMemory<ElementType>& conv_input_data,
+      ScaleType conv_input_scale,
       const dnn::FilterDescriptor& filter_descriptor,
-      const DeviceMemory<Type>& filter_data,
+      const DeviceMemory<ElementType>& filter_data,
       const dnn::ConvolutionDescriptor& convolution_descriptor,
-      const DeviceMemory<Type>& side_input_data, ScaleType side_input_scale,
-      const dnn::BatchDescriptor& bias_descriptor,
+      const DeviceMemory<ElementType>& side_input_data,
+      ScaleType side_input_scale, const dnn::BatchDescriptor& bias_descriptor,
       const DeviceMemory<BiasType>& biases, dnn::ActivationMode activation_mode,
       const dnn::BatchDescriptor& output_descriptor,
-      DeviceMemory<Type>* output_data, ScratchAllocator* scratch_allocator,
+      DeviceMemory<ElementType>* output_data,
+      ScratchAllocator* scratch_allocator,
       const dnn::AlgorithmConfig& algorithm_config,
       dnn::ProfileResult* output_profile_result);
 
diff --git a/tensorflow/stream_executor/device_description.h b/tensorflow/stream_executor/device_description.h
index 7f99d81..a4580d6 100644
--- a/tensorflow/stream_executor/device_description.h
+++ b/tensorflow/stream_executor/device_description.h
@@ -22,8 +22,7 @@
 
 #include <map>
 #include <memory>
-#include "tensorflow/stream_executor/platform/port.h"
-
+#include "absl/base/macros.h"
 #include "tensorflow/stream_executor/launch_dim.h"
 #include "tensorflow/stream_executor/platform/port.h"
 
@@ -359,9 +358,8 @@
 bool ThreadDimOk(const DeviceDescription &device_description,
                  const ThreadDim &thread_dim);
 
-// [deprecated] Use MathUtil::CeilOfRatio directly instead.
-//
 // Equivalent to ceil(double(element_count) / threads_per_block).
+ABSL_DEPRECATED("Use MathUtil::CeilOfRatio directly instead.")
 uint64 DivideCeil(uint64 x, uint64 y);
 
 // Calculate the number of threads/blocks required to process element_count
diff --git a/tensorflow/stream_executor/lib/array_slice.h b/tensorflow/stream_executor/lib/array_slice.h
index 8e3c4ca..5f4e586 100644
--- a/tensorflow/stream_executor/lib/array_slice.h
+++ b/tensorflow/stream_executor/lib/array_slice.h
@@ -16,13 +16,15 @@
 #ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_ARRAY_SLICE_H_
 #define TENSORFLOW_STREAM_EXECUTOR_LIB_ARRAY_SLICE_H_
 
-#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "absl/types/span.h"
 
 namespace stream_executor {
 namespace port {
 
-using tensorflow::gtl::ArraySlice;
-using tensorflow::gtl::MutableArraySlice;
+template <typename T>
+using ArraySlice = absl::Span<const T>;
+template <typename T>
+using MutableArraySlice = absl::Span<T>;
 
 }  // namespace port
 }  // namespace stream_executor
diff --git a/tensorflow/stream_executor/lib/inlined_vector.h b/tensorflow/stream_executor/lib/inlined_vector.h
index 40bdddb..0198947 100644
--- a/tensorflow/stream_executor/lib/inlined_vector.h
+++ b/tensorflow/stream_executor/lib/inlined_vector.h
@@ -16,12 +16,12 @@
 #ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_INLINED_VECTOR_H_
 #define TENSORFLOW_STREAM_EXECUTOR_LIB_INLINED_VECTOR_H_
 
-#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "absl/container/inlined_vector.h"
 
 namespace stream_executor {
 namespace port {
 
-using tensorflow::gtl::InlinedVector;
+using absl::InlinedVector;
 
 }  // namespace port
 }  // namespace stream_executor
diff --git a/tensorflow/stream_executor/lib/strcat.h b/tensorflow/stream_executor/lib/strcat.h
index c959e4d..3688d7b 100644
--- a/tensorflow/stream_executor/lib/strcat.h
+++ b/tensorflow/stream_executor/lib/strcat.h
@@ -18,13 +18,13 @@
 #ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STRCAT_H_
 #define TENSORFLOW_STREAM_EXECUTOR_LIB_STRCAT_H_
 
-#include "tensorflow/core/lib/strings/strcat.h"
+#include "absl/strings/str_cat.h"
 
 namespace stream_executor {
 namespace port {
 
-using tensorflow::strings::StrCat;
-using tensorflow::strings::StrAppend;
+using absl::StrAppend;
+using absl::StrCat;
 
 }  // namespace port
 }  // namespace stream_executor
diff --git a/tensorflow/stream_executor/lib/stringpiece.h b/tensorflow/stream_executor/lib/stringpiece.h
index b80de5d..7624910 100644
--- a/tensorflow/stream_executor/lib/stringpiece.h
+++ b/tensorflow/stream_executor/lib/stringpiece.h
@@ -16,13 +16,12 @@
 #ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STRINGPIECE_H_
 #define TENSORFLOW_STREAM_EXECUTOR_LIB_STRINGPIECE_H_
 
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/stream_executor/platform/port.h"
+#include "absl/strings/string_view.h"
 
 namespace stream_executor {
 namespace port {
 
-using tensorflow::StringPiece;
+using StringPiece = absl::string_view;
 
 }  // namespace port
 }  // namespace stream_executor
diff --git a/tensorflow/stream_executor/plugin_registry.h b/tensorflow/stream_executor/plugin_registry.h
index 49628ec..3065b5c 100644
--- a/tensorflow/stream_executor/plugin_registry.h
+++ b/tensorflow/stream_executor/plugin_registry.h
@@ -18,6 +18,7 @@
 
 #include <map>
 
+#include "absl/base/macros.h"
 #include "tensorflow/stream_executor/blas.h"
 #include "tensorflow/stream_executor/dnn.h"
 #include "tensorflow/stream_executor/fft.h"
@@ -97,6 +98,7 @@
   // TODO(b/22689637): Deprecated/temporary. Will be deleted once all users are
   // on MultiPlatformManager / PlatformId.
   template <typename FactoryT>
+  ABSL_DEPRECATED("Use MultiPlatformManager / PlatformId instead.")
   port::StatusOr<FactoryT> GetFactory(PlatformKind platform_kind,
                                       PluginId plugin_id);
 
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index 19d3b23..69558fd 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -587,6 +587,44 @@
 
 Stream &Stream::ThenFusedConvolveWithAlgorithm(
     const dnn::BatchDescriptor &conv_input_descriptor,
+    const DeviceMemory<double> &conv_input_data, double conv_input_scale,
+    const dnn::FilterDescriptor &filter_descriptor,
+    const DeviceMemory<double> &filter_data,
+    const dnn::ConvolutionDescriptor &convolution_descriptor,
+    const DeviceMemory<double> &side_input_data, double side_input_scale,
+    const dnn::BatchDescriptor &bias_descriptor,
+    const DeviceMemory<double> &biases, dnn::ActivationMode activation_mode,
+    const dnn::BatchDescriptor &output_descriptor, DeviceMemory<double> *output,
+    ScratchAllocator *scratch_allocator,
+    const dnn::AlgorithmConfig &algorithm_config,
+    dnn::ProfileResult *output_profile_result) {
+  VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
+            PARAM(conv_input_scale), PARAM(filter_descriptor),
+            PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
+            PARAM(side_input_data), PARAM(side_input_scale),
+            PARAM(activation_mode), PARAM(output_descriptor), PARAM(output),
+            PARAM(algorithm_config));
+
+  if (ok()) {
+    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+      auto status = dnn->DoFusedConvolve(
+          this, conv_input_descriptor, conv_input_data, conv_input_scale,
+          filter_descriptor, filter_data, convolution_descriptor,
+          side_input_data, side_input_scale, bias_descriptor, biases,
+          activation_mode, output_descriptor, output, scratch_allocator,
+          algorithm_config, output_profile_result);
+      if (!status && !output_profile_result) {
+        SetError();
+      }
+    } else {
+      SetErrorAndLogNoDnnSupport();
+    }
+  }
+  return *this;
+}
+
+Stream &Stream::ThenFusedConvolveWithAlgorithm(
+    const dnn::BatchDescriptor &conv_input_descriptor,
     const DeviceMemory<float> &conv_input_data, float conv_input_scale,
     const dnn::FilterDescriptor &filter_descriptor,
     const DeviceMemory<float> &filter_data,
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index d04025b..4a8a270 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -22,6 +22,7 @@
 #include <tuple>
 #include <vector>
 
+#include "absl/base/macros.h"
 #include "tensorflow/stream_executor/lib/status.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
 #include "tensorflow/stream_executor/lib/strcat.h"
@@ -81,8 +82,8 @@
   port::Status Init();
   port::Status Init(int device_ordinal, DeviceOptions device_options);
 
-  // DEPRECATED: Do not use; use platform() instead.
   // Returns the platform that this StreamExecutor is acting upon.
+  ABSL_DEPRECATED("Use platform() instead.")
   PlatformKind platform_kind() const { return platform_kind_; }
 
   // Returns a reference to the platform that created this executor.
@@ -255,15 +256,15 @@
 
   // [deprecated] Blocks the caller while a data segment of the given size is
   // copied from the host source to the device destination.
-  //
-  // Deprecation: prefer explicit H2D below, to avoid error-prone API usage.
+  ABSL_DEPRECATED(
+      "Prefer SynchronousMemcpyH2D, to avoid error-prone API usage.")
   bool SynchronousMemcpy(DeviceMemoryBase *device_dst, const void *host_src,
                          uint64 size) SE_MUST_USE_RESULT;
 
   // [deprecated] Blocks the caller while a data segment of the given size is
   // copied from the device source to the host destination.
-  //
-  // Deprecation: prefer explicit D2H below, to avoid error-prone API usage.
+  ABSL_DEPRECATED(
+      "Prefer SynchronousMemcpyD2H, to avoid error-prone API usage.")
   bool SynchronousMemcpy(void *host_dst, const DeviceMemoryBase &device_src,
                          uint64 size) SE_MUST_USE_RESULT;
 
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index adac895..689679c 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -448,7 +448,7 @@
     tf_cc_binary(
         name = tool,
         copts = tf_copts(),
-        linkopts = if_not_windows(["-lm"]),
+        linkopts = if_not_windows(["-lm", "-Wl,-ldl"]),
         linkstatic = 1,  # Faster to link this one-time-use binary dynamically
         deps = [op_gen] + deps,
     )
@@ -602,6 +602,7 @@
 #     is invalid to specify both "hidden" and "op_whitelist".
 #   cc_linkopts: Optional linkopts to be added to tf_cc_binary that contains the
 #     specified ops.
+
 def tf_gen_op_wrapper_py(
         name,
         out = None,
@@ -623,7 +624,7 @@
         deps = [str(Label("//tensorflow/core:" + name + "_op_lib"))]
     tf_cc_binary(
         name = tool_name,
-        linkopts = if_not_windows(["-lm"]) + cc_linkopts,
+        linkopts = if_not_windows(["-lm", "-Wl,-ldl"]) + cc_linkopts,
         copts = tf_copts(),
         linkstatic = 1,  # Faster to link this one-time-use binary dynamically
         deps = ([
@@ -1215,9 +1216,11 @@
     if prefix:
         srcs = srcs + native.glob(
             [prefix + "*.cc"],
+            exclude = [prefix + "*test*"],
         )
         hdrs = hdrs + native.glob(
             [prefix + "*.h"],
+            exclude = [prefix + "*test*"],
         )
 
     # -fno-exceptions in nocopts breaks compilation if header modules are enabled.
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt
index 8774542..c3ba2db 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt
@@ -111,6 +111,10 @@
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
+    name: "window"
+    argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+  }
+  member_method {
     name: "zip"
     argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
   }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt
index 6dd4636..3541671 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt
@@ -112,6 +112,10 @@
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
+    name: "window"
+    argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+  }
+  member_method {
     name: "zip"
     argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
   }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt
index 35b7105..b113c18 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt
@@ -112,6 +112,10 @@
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
+    name: "window"
+    argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+  }
+  member_method {
     name: "zip"
     argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
   }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt
index 8ae370a..7210bf5 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt
@@ -112,6 +112,10 @@
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
+    name: "window"
+    argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+  }
+  member_method {
     name: "zip"
     argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
   }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-classifier.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-classifier.pbtxt
index 7027e78..9e429a3 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-classifier.pbtxt
@@ -1,6 +1,7 @@
 path: "tensorflow.estimator.BoostedTreesClassifier"
 tf_class {
   is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees.BoostedTreesClassifier\'>"
+  is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees._BoostedTreesBase\'>"
   is_instance: "<class \'tensorflow.python.estimator.estimator.Estimator\'>"
   is_instance: "<type \'object\'>"
   member {
@@ -32,6 +33,10 @@
     argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
+    name: "experimental_predict_with_explanations"
+    argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+  }
+  member_method {
     name: "export_saved_model"
     argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
   }
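
A hedged sketch of how the new `experimental_predict_with_explanations` entry point might be exercised on a trained boosted-trees estimator. The toy data, the bucketized feature column, and the `'dfc'` (per-feature contribution) key are assumptions for illustration; consult the estimator documentation for the exact prediction keys it returns.

```python
import numpy as np
import tensorflow as tf

def input_fn():
    # Toy pipeline (assumed): one numeric feature and binary labels.
    features = {'x': np.random.rand(64).astype(np.float32)}
    labels = np.random.randint(0, 2, size=(64,))
    return tf.data.Dataset.from_tensor_slices((features, labels)).repeat().batch(32)

# Boosted trees expect bucketized (or indicator) columns.
fc = tf.feature_column.bucketized_column(
    tf.feature_column.numeric_column('x'), boundaries=[0.25, 0.5, 0.75])

est = tf.estimator.BoostedTreesClassifier([fc], n_batches_per_layer=2)
est.train(input_fn, max_steps=10)

for pred in est.experimental_predict_with_explanations(input_fn):
    # Expected to carry per-feature contributions alongside the usual keys.
    print(pred['probabilities'], pred.get('dfc'))
    break
```
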
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-regressor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-regressor.pbtxt
index d8167ea..56af1d1 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-regressor.pbtxt
@@ -1,6 +1,7 @@
 path: "tensorflow.estimator.BoostedTreesRegressor"
 tf_class {
   is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees.BoostedTreesRegressor\'>"
+  is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees._BoostedTreesBase\'>"
   is_instance: "<class \'tensorflow.python.estimator.estimator.Estimator\'>"
   is_instance: "<type \'object\'>"
   member {
@@ -32,6 +33,10 @@
     argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
+    name: "experimental_predict_with_explanations"
+    argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+  }
+  member_method {
     name: "export_saved_model"
     argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
   }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.pbtxt
index 73b577d..a296e13 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.pbtxt
@@ -105,6 +105,10 @@
     argspec: "args=[\'metric\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
+    name: "sparse_categorical_accuracy"
+    argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
     name: "sparse_categorical_crossentropy"
     argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
   }
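
A minimal sketch of the newly exported `sparse_categorical_accuracy` metric, passed directly to `compile()`; the toy data and tiny model are assumptions for illustration.

```python
import numpy as np
import tensorflow as tf

# Integer labels with a softmax head, tracked by the exported metric function.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(3, activation='softmax', input_shape=(4,)),
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=[tf.keras.metrics.sparse_categorical_accuracy])
model.fit(np.random.rand(32, 4), np.random.randint(0, 3, size=(32,)),
          epochs=1, verbose=0)
```
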
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
index dd9f7c4..fbc58e5 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
@@ -1373,6 +1373,10 @@
     argspec: "args=[\'library_filename\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
+    name: "load_library"
+    argspec: "args=[\'library_location\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
     name: "load_op_library"
     argspec: "args=[\'library_filename\'], varargs=None, keywords=None, defaults=None"
   }
@@ -1589,6 +1593,10 @@
     argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
+    name: "print"
+    argspec: "args=[], varargs=inputs, keywords=kwargs, defaults=None"
+  }
+  member_method {
     name: "py_func"
     argspec: "args=[\'func\', \'inp\', \'Tout\', \'stateful\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
   }
@@ -1797,6 +1805,10 @@
     argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
   }
   member_method {
+    name: "searchsorted"
+    argspec: "args=[\'sorted_sequence\', \'values\', \'side\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'left\', \"<dtype: \'int32\'>\", \'None\'], "
+  }
+  member_method {
     name: "segment_max"
     argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
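
For the `print` and `searchsorted` symbols added to this golden, a small sketch of their call shapes; eager execution is assumed here (in graph mode `tf.print` returns an op to run), and the library path in the comment is hypothetical.

```python
import tensorflow as tf
tf.enable_eager_execution()  # assumed for this sketch

seq = tf.constant([[2.0, 5.0, 9.0]])
vals = tf.constant([[4.0, 9.5]])
tf.print("sequence:", seq, "values:", vals)    # varargs-style printing of tensor contents
idx = tf.searchsorted(seq, vals, side='left')  # insertion indices that keep seq sorted: [[1, 3]]
# tf.load_library('/path/to/custom_ops.so')    # hypothetical plugin path for the new loader
```
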
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
index 018be7b..c81c156 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
@@ -1,6 +1,10 @@
 path: "tensorflow.strings"
 tf_module {
   member_method {
+    name: "format"
+    argspec: "args=[\'template\', \'inputs\', \'placeholder\', \'summarize\', \'name\'], varargs=None, keywords=None, defaults=[\'{}\', \'3\', \'None\'], "
+  }
+  member_method {
     name: "join"
     argspec: "args=[\'inputs\', \'separator\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
   }
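
A short sketch of the new `tf.strings.format`, which substitutes tensor summaries for `{}` placeholders in a template; eager execution is assumed as in the sketches above.

```python
import tensorflow as tf
tf.enable_eager_execution()  # assumed for this sketch

t = tf.constant([[1, 2], [3, 4]])
msg = tf.strings.format("tensor is {}", (t,))  # one placeholder per input tensor
print(msg)  # scalar string tensor summarizing t
```
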
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
index 8774542..c3ba2db 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
@@ -111,6 +111,10 @@
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
+    name: "window"
+    argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+  }
+  member_method {
     name: "zip"
     argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
index 6dd4636..3541671 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
@@ -112,6 +112,10 @@
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
+    name: "window"
+    argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+  }
+  member_method {
     name: "zip"
     argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
index 35b7105..b113c18 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
@@ -112,6 +112,10 @@
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
+    name: "window"
+    argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+  }
+  member_method {
     name: "zip"
     argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
index 8ae370a..7210bf5 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
@@ -112,6 +112,10 @@
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
+    name: "window"
+    argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+  }
+  member_method {
     name: "zip"
     argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt
index 7027e78..9e429a3 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt
@@ -1,6 +1,7 @@
 path: "tensorflow.estimator.BoostedTreesClassifier"
 tf_class {
   is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees.BoostedTreesClassifier\'>"
+  is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees._BoostedTreesBase\'>"
   is_instance: "<class \'tensorflow.python.estimator.estimator.Estimator\'>"
   is_instance: "<type \'object\'>"
   member {
@@ -32,6 +33,10 @@
     argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
+    name: "experimental_predict_with_explanations"
+    argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+  }
+  member_method {
     name: "export_saved_model"
     argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt
index d8167ea..56af1d1 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt
@@ -1,6 +1,7 @@
 path: "tensorflow.estimator.BoostedTreesRegressor"
 tf_class {
   is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees.BoostedTreesRegressor\'>"
+  is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees._BoostedTreesBase\'>"
   is_instance: "<class \'tensorflow.python.estimator.estimator.Estimator\'>"
   is_instance: "<type \'object\'>"
   member {
@@ -32,6 +33,10 @@
     argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
+    name: "experimental_predict_with_explanations"
+    argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+  }
+  member_method {
     name: "export_saved_model"
     argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.pbtxt
index 73b577d..a296e13 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.pbtxt
@@ -105,6 +105,10 @@
     argspec: "args=[\'metric\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
+    name: "sparse_categorical_accuracy"
+    argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
     name: "sparse_categorical_crossentropy"
     argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
index 9332e16..7eca26b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
@@ -581,10 +581,6 @@
     argspec: "args=[\'op_type\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
-    name: "Print"
-    argspec: "args=[\'input_\', \'data\', \'message\', \'first_n\', \'summarize\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
-  }
-  member_method {
     name: "abs"
     argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
@@ -1321,6 +1317,10 @@
     argspec: "args=[\'library_filename\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
+    name: "load_library"
+    argspec: "args=[\'library_location\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
     name: "load_op_library"
     argspec: "args=[\'library_filename\'], varargs=None, keywords=None, defaults=None"
   }
@@ -1537,6 +1537,10 @@
     argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
+    name: "print"
+    argspec: "args=[], varargs=inputs, keywords=kwargs, defaults=None"
+  }
+  member_method {
     name: "py_func"
     argspec: "args=[\'func\', \'inp\', \'Tout\', \'stateful\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
   }
@@ -1721,6 +1725,10 @@
     argspec: "args=[\'indices\', \'updates\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
+    name: "searchsorted"
+    argspec: "args=[\'sorted_sequence\', \'values\', \'side\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'left\', \"<dtype: \'int32\'>\", \'None\'], "
+  }
+  member_method {
     name: "segment_max"
     argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
index 018be7b..c81c156 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
@@ -1,6 +1,10 @@
 path: "tensorflow.strings"
 tf_module {
   member_method {
+    name: "format"
+    argspec: "args=[\'template\', \'inputs\', \'placeholder\', \'summarize\', \'name\'], varargs=None, keywords=None, defaults=[\'{}\', \'3\', \'None\'], "
+  }
+  member_method {
     name: "join"
     argspec: "args=[\'inputs\', \'separator\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
   }
diff --git a/tensorflow/tools/ci_build/Dockerfile.gpu.ppc64le b/tensorflow/tools/ci_build/Dockerfile.gpu.ppc64le
index e026edb..0a55b84 100644
--- a/tensorflow/tools/ci_build/Dockerfile.gpu.ppc64le
+++ b/tensorflow/tools/ci_build/Dockerfile.gpu.ppc64le
@@ -1,4 +1,4 @@
-FROM nvidia/cuda-ppc64le:9.0-cudnn7-devel-ubuntu16.04
+FROM nvidia/cuda-ppc64le:9.2-cudnn7-devel-ubuntu16.04
 
 LABEL maintainer="William Irons <wdirons@us.ibm.com>"
 
@@ -26,6 +26,8 @@
 # Configure the build for our CUDA configuration.
 ENV TF_NEED_CUDA 1
 ENV TF_CUDA_COMPUTE_CAPABILITIES 3.0
+ENV TF_CUDA_VERSION 9.2
+ENV CUDA_TOOLKIT_PATH /usr/local/cuda-9.2
 
 # TODO get NCCL 2 in the docker image
 ENV TF_NCCL_VERSION 1
diff --git a/tensorflow/tools/ci_build/README.md b/tensorflow/tools/ci_build/README.md
index f2161b7..e2fd977 100644
--- a/tensorflow/tools/ci_build/README.md
+++ b/tensorflow/tools/ci_build/README.md
@@ -24,7 +24,7 @@
 
 ### Run TensorFlow CI Scripts Natively on your Machine
 
-1.  Follow the instructions at https://www.tensorflow.org/install/install_sources,
+1.  Follow the instructions at https://www.tensorflow.org/install/source,
     but stop when you get to the section "Configure the installation". You do not
     need to configure the installation to run the CI scripts.
 
diff --git a/tensorflow/tools/ci_build/builds/run_pip_tests.sh b/tensorflow/tools/ci_build/builds/run_pip_tests.sh
index 4b762bf..17198a6 100755
--- a/tensorflow/tools/ci_build/builds/run_pip_tests.sh
+++ b/tensorflow/tools/ci_build/builds/run_pip_tests.sh
@@ -64,7 +64,7 @@
   fi
 done
 
-TF_GPU_COUNT=${TF_GPU_COUNT:-8}
+TF_GPU_COUNT=${TF_GPU_COUNT:-4}
 
 # PIP tests should have a "different" path. Different than the one we place
 # virtualenv, because we are deleting and recreating it here.
diff --git a/tensorflow/tools/ci_build/ci_parameterized_build.sh b/tensorflow/tools/ci_build/ci_parameterized_build.sh
index cc09784c..49a9048 100755
--- a/tensorflow/tools/ci_build/ci_parameterized_build.sh
+++ b/tensorflow/tools/ci_build/ci_parameterized_build.sh
@@ -147,7 +147,7 @@
 ANDROID_CMD="${CI_BUILD_DIR}/builds/android.sh"
 ANDROID_FULL_CMD="${CI_BUILD_DIR}/builds/android_full.sh"
 
-TF_GPU_COUNT=${TF_GPU_COUNT:-8}
+TF_GPU_COUNT=${TF_GPU_COUNT:-4}
 PARALLEL_GPU_TEST_CMD='//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute'
 
 BENCHMARK_CMD="${CI_BUILD_DIR}/builds/benchmark.sh"
diff --git a/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh b/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh
index 03a2a07..cd7206b 100755
--- a/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh
+++ b/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh
@@ -21,8 +21,8 @@
 # Required environment variables:
 #     TF_GPU_COUNT = Number of GPUs available.
 
-TF_GPU_COUNT=${TF_GPU_COUNT:-8}
-TF_TESTS_PER_GPU=${TF_TESTS_PER_GPU:-4}
+TF_GPU_COUNT=${TF_GPU_COUNT:-4}
+TF_TESTS_PER_GPU=${TF_TESTS_PER_GPU:-8}
 # We want to allow running one of the following configs:
 #  - 4 tests per GPU on k80
 #  - 8 tests per GPU on p100
diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh
index a9ae715..4ced96f 100755
--- a/tensorflow/tools/ci_build/install/install_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh
@@ -68,8 +68,8 @@
   pip3 install --upgrade numpy==1.14.5
 fi
 
-pip2 install scipy==0.18.1
-pip3 install scipy==0.18.1
+pip2 install scipy==1.1.0
+pip3 install scipy==1.1.0
 
 pip2 install scikit-learn==0.18.1
 pip3 install scikit-learn==0.18.1
diff --git a/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh b/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh
index 28d5565..34847e6 100644
--- a/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh
+++ b/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh
@@ -122,7 +122,7 @@
 PIP_NAME=$(ls ${PY_TEST_DIR}/tensorflow_gpu-*.whl)
 reinstall_tensorflow_pip ${PIP_NAME}
 
-TF_GPU_COUNT=${TF_GPU_COUNT:-8}
+TF_GPU_COUNT=${TF_GPU_COUNT:-4}
 
 # Define no_tensorflow_py_deps=true so that every py_test has no deps anymore,
 # which will result testing system installed tensorflow
diff --git a/tensorflow/tools/dist_test/README.md b/tensorflow/tools/dist_test/README.md
index 228d5ee..f8ed74a 100644
--- a/tensorflow/tools/dist_test/README.md
+++ b/tensorflow/tools/dist_test/README.md
@@ -23,7 +23,7 @@
 ./local_test.sh ${whl_file_url}
 ```
 
-For example, you can find these TensorFlow python package URLs from [here](https://www.tensorflow.org/install/install_linux#the_url_of_the_tensorflow_python_package) for Ubuntu.
+For example, you can find these TensorFlow python package URLs from [here](https://www.tensorflow.org/install/pip) for Ubuntu.
 
 **2) Launch a remote k8s cluster on Google Kubernetes Engine (GKE) and run the
 test suite on it**
diff --git a/tensorflow/tools/docker/Dockerfile.devel b/tensorflow/tools/docker/Dockerfile.devel
index 39e7bc8..c741e8a 100644
--- a/tensorflow/tools/docker/Dockerfile.devel
+++ b/tensorflow/tools/docker/Dockerfile.devel
@@ -78,7 +78,7 @@
 
 # Download and build TensorFlow.
 WORKDIR /tensorflow
-RUN git clone --branch=r1.10 --depth=1 https://github.com/tensorflow/tensorflow.git .
+RUN git clone --branch=r1.11 --depth=1 https://github.com/tensorflow/tensorflow.git .
 
 # TODO(craigcitro): Don't install the pip package, since it makes it
 # more difficult to experiment with local changes. Instead, just add
diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu b/tensorflow/tools/docker/Dockerfile.devel-gpu
index e487779..f544725 100644
--- a/tensorflow/tools/docker/Dockerfile.devel-gpu
+++ b/tensorflow/tools/docker/Dockerfile.devel-gpu
@@ -100,7 +100,7 @@
 
 # Download and build TensorFlow.
 WORKDIR /tensorflow
-RUN git clone --branch=r1.10 --depth=1 https://github.com/tensorflow/tensorflow.git .
+RUN git clone --branch=r1.11 --depth=1 https://github.com/tensorflow/tensorflow.git .
 
 # Configure the build for our CUDA configuration.
 ENV CI_BUILD_PYTHON python
diff --git a/tensorflow/tools/docker/Dockerfile.devel-mkl b/tensorflow/tools/docker/Dockerfile.devel-mkl
index 371451d..db7c701 100755
--- a/tensorflow/tools/docker/Dockerfile.devel-mkl
+++ b/tensorflow/tools/docker/Dockerfile.devel-mkl
@@ -3,7 +3,7 @@
 LABEL maintainer="Clayne Robison <clayne.b.robison@intel.com>"
 
 # These parameters can be overridden by parameterized_docker_build.sh
-ARG TF_BUILD_VERSION=r1.10
+ARG TF_BUILD_VERSION=r1.11
 ARG PYTHON="python"
 ARG PYTHON3_DEV=""
 ARG WHL_DIR="/tmp/pip"
diff --git a/tensorflow/tools/docs/BUILD b/tensorflow/tools/docs/BUILD
index 4f7efe1..b218e90 100644
--- a/tensorflow/tools/docs/BUILD
+++ b/tensorflow/tools/docs/BUILD
@@ -91,9 +91,10 @@
         ":parser",
         ":pretty_docs",
         ":py_guide_parser",
-        "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
+        "//tensorflow/python:util",
         "//tensorflow/tools/common:public_api",
         "//tensorflow/tools/common:traverse",
+        "@six_archive//:six",
     ],
 )
 
diff --git a/tensorflow/tools/docs/generate_lib.py b/tensorflow/tools/docs/generate_lib.py
index 1cd9cb7..77a3ca2 100644
--- a/tensorflow/tools/docs/generate_lib.py
+++ b/tensorflow/tools/docs/generate_lib.py
@@ -453,7 +453,11 @@
 EXCLUDED = set(['__init__.py', 'OWNERS', 'README.txt'])
 
 
-def replace_refs(src_dir, output_dir, reference_resolver, file_pattern='*.md'):
+def replace_refs(src_dir,
+                 output_dir,
+                 reference_resolver,
+                 file_pattern='*.md',
+                 api_docs_relpath='api_docs'):
   """Fix @{} references in all files under `src_dir` matching `file_pattern`.
 
   A matching directory structure, with the modified files is
@@ -472,12 +476,13 @@
     reference_resolver: A `parser.ReferenceResolver` to make the replacements.
     file_pattern: Only replace references in files matching file_patters,
       using fnmatch. Non-matching files are copied unchanged.
+    api_docs_relpath: Relative-path string to the api_docs, from the src_dir.
   """
   # Iterate through all the source files and process them.
   for dirpath, _, filenames in os.walk(src_dir):
+    depth = os.path.relpath(src_dir, start=dirpath)
     # How to get from `dirpath` to api_docs/python/
-    relative_path_to_root = os.path.relpath(
-        path=os.path.join(src_dir, 'api_docs/python'), start=dirpath)
+    relative_path_to_root = os.path.join(depth, api_docs_relpath, 'python')
 
     # Make the directory under output_dir.
     new_dir = os.path.join(output_dir,
@@ -497,7 +502,8 @@
       full_out_path = os.path.join(output_dir, suffix)
       # Copy files that do not match the file_pattern, unmodified.
       if not fnmatch.fnmatch(base_name, file_pattern):
-        shutil.copyfile(full_in_path, full_out_path)
+        if full_in_path != full_out_path:
+          shutil.copyfile(full_in_path, full_out_path)
         continue
 
       with open(full_in_path, 'rb') as f:
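
A small illustration of the relative-path computation introduced above; the directory names are made up. From a nested source directory, `depth` climbs back to the docs root, so the `api_docs` link resolves regardless of nesting depth.

```python
import os

src_dir = 'docs_src'
dirpath = 'docs_src/tutorials/images'
depth = os.path.relpath(src_dir, start=dirpath)   # '../..'
print(os.path.join(depth, 'api_docs', 'python'))  # '../../api_docs/python'
```
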
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index 50515b0..f86cb03 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -114,6 +114,7 @@
     "//tensorflow/python/tools:tools_pip",
     "//tensorflow/python/tools/api/generator:create_python_api",
     "//tensorflow/python:test_ops",
+    "//tensorflow/python:while_v2",
     "//tensorflow/tools/dist_test/server:grpc_tensorflow_server",
 ]
 
@@ -210,6 +211,7 @@
         "@ngraph//:LICENSE",
         "@ngraph_tf//:LICENSE",
         "@nlohmann_json_lib//:LICENSE.MIT",
+        "@tbb//:LICENSE",
     ]) + tf_additional_license_deps(),
 )
 
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index 3102239..d40ffb8 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -45,7 +45,7 @@
 # This version string is semver compatible, but incompatible with pip.
 # For pip, we will remove all '-' characters from this string, and use the
 # result for pip.
-_VERSION = '1.10.0'
+_VERSION = '1.11.0-rc1'
 
 REQUIRED_PACKAGES = [
     'absl-py >= 0.1.6',
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 25698da..d0531f8 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -491,11 +491,11 @@
     tf_http_archive(
         name = "llvm",
         urls = [
-            "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/738b5f5028ef39cbb023967f80fa2e5dd568556b.tar.gz",
-            "https://github.com/llvm-mirror/llvm/archive/738b5f5028ef39cbb023967f80fa2e5dd568556b.tar.gz",
+            "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/db98902adc6431c9cc4ddec50fe174cfc9e626d6.tar.gz",
+            "https://github.com/llvm-mirror/llvm/archive/db98902adc6431c9cc4ddec50fe174cfc9e626d6.tar.gz",
         ],
-        sha256 = "2bda8dd724ab432c162fb6eace259ccf8a97f13cb627336611bff68da2f33ec2",
-        strip_prefix = "llvm-738b5f5028ef39cbb023967f80fa2e5dd568556b",
+        sha256 = "8c02d312b3d417cf9bc7e58ff53c2528bf77a5d839ce4a23b95bd04b9e5da023",
+        strip_prefix = "llvm-db98902adc6431c9cc4ddec50fe174cfc9e626d6",
         build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"),
     )
 
@@ -831,13 +831,24 @@
     )
 
     tf_http_archive(
+        name = "tbb",
+        urls = [
+            "https://mirror.bazel.build/github.com/01org/tbb/archive/tbb_2018.zip",
+            "https://github.com/01org/tbb/archive/tbb_2018.zip",
+        ],
+        sha256 = "724686f90bcda78f13b76f297d964008737ccd6399328143c1c0093e73ae6a13",
+        strip_prefix = "tbb-tbb_2018",
+        build_file = clean_dep("//third_party/ngraph:tbb.BUILD"),
+    )
+
+    tf_http_archive(
         name = "ngraph",
         urls = [
-            "https://mirror.bazel.build/github.com/NervanaSystems/ngraph/archive/v0.5.0.tar.gz",
-            "https://github.com/NervanaSystems/ngraph/archive/v0.5.0.tar.gz",
+            "https://mirror.bazel.build/github.com/NervanaSystems/ngraph/archive/v0.7.0.tar.gz",
+            "https://github.com/NervanaSystems/ngraph/archive/v0.7.0.tar.gz",
         ],
-        sha256 = "cb35d3d98836f615408afd18371fb13e3400711247e0d822ba7f306c45e9bb2c",
-        strip_prefix = "ngraph-0.5.0",
+        sha256 = "34434b6d5993ac5233538c84f498840db7ac91df82e225c379ee7c8f6de644a5",
+        strip_prefix = "ngraph-0.7.0",
         build_file = clean_dep("//third_party/ngraph:ngraph.BUILD"),
     )
 
@@ -855,11 +866,11 @@
     tf_http_archive(
         name = "ngraph_tf",
         urls = [
-            "https://mirror.bazel.build/github.com/NervanaSystems/ngraph-tf/archive/v0.3.0-rc1.tar.gz",
-            "https://github.com/NervanaSystems/ngraph-tf/archive/v0.3.0-rc1.tar.gz",
+            "https://mirror.bazel.build/github.com/NervanaSystems/ngraph-tf/archive/v0.5.0.tar.gz",
+            "https://github.com/NervanaSystems/ngraph-tf/archive/v0.5.0.tar.gz",
         ],
-        sha256 = "7919332cb15120101c3e05c1b969a5e029a6411581312583c8f80b6aaaa83072",
-        strip_prefix = "ngraph-tf-0.3.0-rc1",
+        sha256 = "23b4566d8e40d6f1f236b0ffe3905dd964ae42ca54bacff67f24abcefd443afb",
+        strip_prefix = "ngraph-tf-0.5.0",
         build_file = clean_dep("//third_party/ngraph:ngraph_tf.BUILD"),
     )